mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-10 02:38:26 +00:00
fix(skills): enforce allowed-tools metadata (#2626)
* fix(skills): parse allowed-tools frontmatter * fix(skills): validate allowed-tools metadata * fix(skills): add shared allowed-tools policy * fix(subagents): enforce skill allowed-tools * fix(agent): enforce skill allowed-tools * refactor(skills): dedupe TypeVar and reuse cached enabled skills - Drop redundant module-level TypeVar in tool_policy; rely on PEP 695 syntax. - Expose get_cached_enabled_skills() and have the lead agent reuse it instead of synchronously rescanning skills on every request. * fix(agent): expose config-scoped skill cache * fix(subagents): pass filtered tools explicitly * fix(skills): clean allowed-tools policy feedback
This commit is contained in:
parent
2b0e62f679
commit
cef4224381
@ -20,6 +20,8 @@ from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -308,6 +310,28 @@ def _build_middlewares(
|
||||
return middlewares
|
||||
|
||||
|
||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
||||
if is_bootstrap:
|
||||
return {"bootstrap"}
|
||||
if agent_config and agent_config.skills is not None:
|
||||
return set(agent_config.skills)
|
||||
return None
|
||||
|
||||
|
||||
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
|
||||
try:
|
||||
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
|
||||
|
||||
skills = get_enabled_skills_for_config(app_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to load skills for allowed-tools policy")
|
||||
raise
|
||||
|
||||
if available_skills is None:
|
||||
return skills
|
||||
return [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
|
||||
runtime_config = _get_runtime_config(config)
|
||||
@ -333,6 +357,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
agent_name = validate_agent_name(cfg.get("agent_name"))
|
||||
|
||||
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
||||
available_skills = _available_skill_names(agent_config, is_bootstrap)
|
||||
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
@ -371,15 +396,18 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
"is_plan_mode": is_plan_mode,
|
||||
"subagent_enabled": subagent_enabled,
|
||||
"tool_groups": agent_config.tool_groups if agent_config else None,
|
||||
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
|
||||
"available_skills": sorted(available_skills) if available_skills is not None else None,
|
||||
}
|
||||
)
|
||||
|
||||
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
|
||||
|
||||
if is_bootstrap:
|
||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
|
||||
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
|
||||
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
@ -394,15 +422,10 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
# The default agent (no agent_name) does not see this tool.
|
||||
extra_tools = [update_agent] if agent_name else []
|
||||
# Default lead agent (unchanged behavior)
|
||||
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
|
||||
tools=get_available_tools(
|
||||
model_name=model_name,
|
||||
groups=agent_config.tool_groups if agent_config else None,
|
||||
subagent_enabled=subagent_enabled,
|
||||
app_config=resolved_app_config,
|
||||
)
|
||||
+ extra_tools,
|
||||
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
|
||||
@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
_enabled_skills_lock = threading.Lock()
|
||||
_enabled_skills_cache: list[Skill] | None = None
|
||||
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event = threading.Event()
|
||||
@ -84,6 +85,7 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_by_config_cache.clear()
|
||||
_enabled_skills_refresh_version += 1
|
||||
_enabled_skills_refresh_event.clear()
|
||||
if _enabled_skills_refresh_active:
|
||||
@ -107,6 +109,15 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
return get_cached_enabled_skills()
|
||||
|
||||
|
||||
def get_cached_enabled_skills() -> list[Skill]:
|
||||
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
|
||||
|
||||
Safe to call from request paths: never blocks on disk I/O. Returns an empty
|
||||
list on cache miss; the next call will see the warmed result.
|
||||
"""
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_cache
|
||||
|
||||
@ -117,17 +128,29 @@ def _get_enabled_skills():
|
||||
return []
|
||||
|
||||
|
||||
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
||||
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
||||
"""Return enabled skills using the caller's config source.
|
||||
|
||||
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
|
||||
cache so the skill list and skill paths are resolved from the same config
|
||||
object. This keeps request-scoped config injection consistent even while the
|
||||
release branch still supports global fallback paths.
|
||||
When a concrete ``app_config`` is supplied, cache the loaded skills by that
|
||||
config object's identity so request-scoped config injection still resolves
|
||||
skill paths from the matching config without rescanning storage on every
|
||||
agent factory call.
|
||||
"""
|
||||
if app_config is None:
|
||||
return _get_enabled_skills()
|
||||
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
||||
|
||||
cache_key = id(app_config)
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_by_config_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
cached_config, cached_skills = cached
|
||||
if cached_config is app_config:
|
||||
return list(cached_skills)
|
||||
|
||||
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
|
||||
return list(skills)
|
||||
|
||||
|
||||
def _skill_mutability_label(category: SkillCategory | str) -> str:
|
||||
@ -605,7 +628,7 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
||||
|
||||
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
|
||||
"""Generate the skills prompt section with available skills list."""
|
||||
skills = _get_enabled_skills_for_config(app_config)
|
||||
skills = get_enabled_skills_for_config(app_config)
|
||||
|
||||
if app_config is None:
|
||||
try:
|
||||
|
||||
@ -9,6 +9,29 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
|
||||
"""Parse the optional allowed-tools frontmatter field.
|
||||
|
||||
Returns None when the field is omitted. Returns a list when the field is a
|
||||
YAML sequence of strings, including an empty list for explicit no-tool
|
||||
skills. Raises ValueError for malformed values.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"allowed-tools in {skill_file} must be a list of strings")
|
||||
|
||||
allowed_tools: list[str] = []
|
||||
for item in raw:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError(f"allowed-tools in {skill_file} must contain only strings")
|
||||
tool_name = item.strip()
|
||||
if not tool_name:
|
||||
raise ValueError(f"allowed-tools in {skill_file} cannot contain empty tool names")
|
||||
allowed_tools.append(tool_name)
|
||||
return allowed_tools
|
||||
|
||||
|
||||
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
|
||||
"""Parse a SKILL.md file and extract metadata.
|
||||
|
||||
@ -64,6 +87,12 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
|
||||
if license_text is not None:
|
||||
license_text = str(license_text).strip() or None
|
||||
|
||||
try:
|
||||
allowed_tools = parse_allowed_tools(metadata.get("allowed-tools"), skill_file)
|
||||
except ValueError as exc:
|
||||
logger.error("Invalid allowed-tools in %s: %s", skill_file, exc)
|
||||
return None
|
||||
|
||||
return Skill(
|
||||
name=name,
|
||||
description=description,
|
||||
@ -72,6 +101,7 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
|
||||
skill_file=skill_file,
|
||||
relative_path=relative_path or Path(skill_file.parent.name),
|
||||
category=category,
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True, # Actual state comes from the extensions config file.
|
||||
)
|
||||
|
||||
|
||||
44
backend/packages/harness/deerflow/skills/tool_policy.py
Normal file
44
backend/packages/harness/deerflow/skills/tool_policy.py
Normal file
@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NamedTool(Protocol):
|
||||
name: str
|
||||
|
||||
|
||||
def allowed_tool_names_for_skills(skills: list[Skill]) -> set[str] | None:
|
||||
"""Return the union of explicit skill allowed-tools declarations.
|
||||
|
||||
None means legacy allow-all behavior. It is returned only when no loaded
|
||||
skill declares allowed-tools. Once any skill declares the field, legacy
|
||||
skills without the field contribute no tools instead of disabling the
|
||||
explicit restrictions from other skills.
|
||||
"""
|
||||
if not skills:
|
||||
return None
|
||||
|
||||
allowed: set[str] = set()
|
||||
has_explicit_declaration = False
|
||||
for skill in skills:
|
||||
if skill.allowed_tools is None:
|
||||
continue
|
||||
has_explicit_declaration = True
|
||||
if not skill.allowed_tools:
|
||||
logger.info("Skill %s declared empty allowed-tools", skill.name)
|
||||
allowed.update(skill.allowed_tools)
|
||||
|
||||
if not has_explicit_declaration:
|
||||
return None
|
||||
return allowed
|
||||
|
||||
|
||||
def filter_tools_by_skill_allowed_tools[ToolT: NamedTool](tools: list[ToolT], skills: list[Skill]) -> list[ToolT]:
|
||||
allowed = allowed_tool_names_for_skills(skills)
|
||||
if allowed is None:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name in allowed]
|
||||
@ -27,6 +27,7 @@ class Skill:
|
||||
skill_file: Path
|
||||
relative_path: Path # Relative path from category root to skill directory
|
||||
category: SkillCategory # 'public' or 'custom'
|
||||
allowed_tools: list[str] | None = None
|
||||
enabled: bool = False # Whether this skill is enabled
|
||||
|
||||
@property
|
||||
|
||||
@ -8,6 +8,7 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from deerflow.skills.parser import parse_allowed_tools
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
|
||||
# Allowed properties in SKILL.md frontmatter
|
||||
@ -84,4 +85,9 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
|
||||
if len(description) > 1024:
|
||||
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None
|
||||
|
||||
try:
|
||||
parse_allowed_tools(frontmatter.get("allowed-tools"), skill_md)
|
||||
except ValueError as e:
|
||||
return False, str(e).replace(str(skill_md), SKILL_MD_FILE), None
|
||||
|
||||
return True, "Skill is valid!", name
|
||||
|
||||
@ -23,6 +23,8 @@ from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadSt
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -260,16 +262,16 @@ class SubagentExecutor:
|
||||
# Generate trace_id if not provided (for top-level calls)
|
||||
self.trace_id = trace_id or str(uuid.uuid4())[:8]
|
||||
|
||||
# Filter tools based on config
|
||||
self.tools = _filter_tools(
|
||||
self._base_tools = _filter_tools(
|
||||
tools,
|
||||
config.tools,
|
||||
config.disallowed_tools,
|
||||
)
|
||||
self.tools = self._base_tools
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")
|
||||
|
||||
def _create_agent(self):
|
||||
def _create_agent(self, tools: list[BaseTool] | None = None):
|
||||
"""Create the agent instance."""
|
||||
app_config = self.app_config or get_app_config()
|
||||
if self.model_name is None:
|
||||
@ -283,26 +285,14 @@ class SubagentExecutor:
|
||||
|
||||
return create_agent(
|
||||
model=model,
|
||||
tools=self.tools,
|
||||
tools=tools if tools is not None else self.tools,
|
||||
middleware=middlewares,
|
||||
system_prompt=self.config.system_prompt,
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
async def _load_skill_messages(self) -> list[SystemMessage]:
|
||||
"""Load skill content as conversation items based on config.skills.
|
||||
|
||||
Aligned with Codex's pattern: each subagent loads its own skills
|
||||
per-session and injects them as conversation items (developer messages),
|
||||
not as system prompt text. The config.skills whitelist controls which
|
||||
skills are loaded:
|
||||
- None: load all enabled skills
|
||||
- []: no skills
|
||||
- ["skill-a", "skill-b"]: only these skills
|
||||
|
||||
Returns:
|
||||
List of SystemMessages containing skill content.
|
||||
"""
|
||||
async def _load_skills(self) -> list[Skill]:
|
||||
"""Load enabled skill metadata based on config.skills."""
|
||||
if self.config.skills is not None and len(self.config.skills) == 0:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
|
||||
return []
|
||||
@ -316,8 +306,8 @@ class SubagentExecutor:
|
||||
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
|
||||
except Exception:
|
||||
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
|
||||
return []
|
||||
logger.exception(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}")
|
||||
raise
|
||||
|
||||
if not all_skills:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
|
||||
@ -326,10 +316,26 @@ class SubagentExecutor:
|
||||
# Filter by config.skills whitelist
|
||||
if self.config.skills is not None:
|
||||
allowed = set(self.config.skills)
|
||||
skills = [s for s in all_skills if s.name in allowed]
|
||||
else:
|
||||
skills = all_skills
|
||||
return [s for s in all_skills if s.name in allowed]
|
||||
return all_skills
|
||||
|
||||
def _apply_skill_allowed_tools(self, skills: list[Skill]) -> list[BaseTool]:
|
||||
return filter_tools_by_skill_allowed_tools(self._base_tools, skills)
|
||||
|
||||
async def _load_skill_messages(self, skills: list[Skill]) -> list[SystemMessage]:
|
||||
"""Load skill content as conversation items based on config.skills.
|
||||
|
||||
Aligned with Codex's pattern: each subagent loads its own skills
|
||||
per-session and injects them as conversation items (developer messages),
|
||||
not as system prompt text. The config.skills whitelist controls which
|
||||
skills are loaded:
|
||||
- None: load all enabled skills
|
||||
- []: no skills
|
||||
- ["skill-a", "skill-b"]: only these skills
|
||||
|
||||
Returns:
|
||||
List of SystemMessages containing skill content.
|
||||
"""
|
||||
if not skills:
|
||||
return []
|
||||
|
||||
@ -347,19 +353,21 @@ class SubagentExecutor:
|
||||
|
||||
return messages
|
||||
|
||||
async def _build_initial_state(self, task: str) -> dict[str, Any]:
|
||||
async def _build_initial_state(self, task: str) -> tuple[dict[str, Any], list[BaseTool]]:
|
||||
"""Build the initial state for agent execution.
|
||||
|
||||
Args:
|
||||
task: The task description.
|
||||
|
||||
Returns:
|
||||
Initial state dictionary.
|
||||
Initial state dictionary and tools filtered by loaded skill metadata.
|
||||
"""
|
||||
# Load skills as conversation items (Codex pattern)
|
||||
skill_messages = await self._load_skill_messages()
|
||||
skills = await self._load_skills()
|
||||
filtered_tools = self._apply_skill_allowed_tools(skills)
|
||||
skill_messages = await self._load_skill_messages(skills)
|
||||
|
||||
messages: list = []
|
||||
messages: list[Any] = []
|
||||
# Skill content injected as developer/system messages before the task
|
||||
messages.extend(skill_messages)
|
||||
# Then the actual task
|
||||
@ -375,7 +383,7 @@ class SubagentExecutor:
|
||||
if self.thread_data is not None:
|
||||
state["thread_data"] = self.thread_data
|
||||
|
||||
return state
|
||||
return state, filtered_tools
|
||||
|
||||
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||
"""Execute a task asynchronously.
|
||||
@ -405,8 +413,8 @@ class SubagentExecutor:
|
||||
result.ai_messages = ai_messages
|
||||
|
||||
try:
|
||||
agent = self._create_agent()
|
||||
state = await self._build_initial_state(task)
|
||||
state, filtered_tools = await self._build_initial_state(task)
|
||||
agent = self._create_agent(filtered_tools)
|
||||
|
||||
# Build config with thread_id for sandbox access and recursion limit
|
||||
run_config: RunnableConfig = {
|
||||
|
||||
@ -1,17 +1,20 @@
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.skills.types import Skill, SkillCategory
|
||||
|
||||
|
||||
def _set_skills_cache_state(*, skills=None, active=False, version=0):
|
||||
prompt_module._get_cached_skills_prompt_section.cache_clear()
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = skills
|
||||
prompt_module._enabled_skills_by_config_cache.clear()
|
||||
prompt_module._enabled_skills_refresh_active = active
|
||||
prompt_module._enabled_skills_refresh_version = version
|
||||
prompt_module._enabled_skills_refresh_event.clear()
|
||||
@ -232,7 +235,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@ -252,6 +255,58 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
|
||||
def make_skill(name: str) -> Skill:
|
||||
skill_dir = tmp_path / name
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
license="MIT",
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
config = cast(
|
||||
AppConfig,
|
||||
cast(
|
||||
object,
|
||||
SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
),
|
||||
),
|
||||
)
|
||||
load_count = 0
|
||||
|
||||
def fake_get_or_new_skill_storage(**kwargs):
|
||||
nonlocal load_count
|
||||
assert kwargs == {"app_config": config}
|
||||
|
||||
def load_skills(*, enabled_only):
|
||||
nonlocal load_count
|
||||
load_count += 1
|
||||
assert enabled_only is True
|
||||
return [make_skill("cached-skill")]
|
||||
|
||||
return SimpleNamespace(load_skills=load_skills)
|
||||
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
_set_skills_cache_state()
|
||||
|
||||
try:
|
||||
first = prompt_module.get_skills_prompt_section(app_config=config)
|
||||
second = prompt_module.get_skills_prompt_section(app_config=config)
|
||||
|
||||
assert "cached-skill" in first
|
||||
assert "cached-skill" in second
|
||||
assert load_count == 1
|
||||
finally:
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
@ -269,7 +324,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,12 @@ from deerflow.config.agents_config import AgentConfig
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def _make_skill(name: str) -> Skill:
|
||||
class NamedTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
@ -15,6 +20,7 @@ def _make_skill(name: str) -> Skill:
|
||||
skill_file=Path(f"/tmp/{name}/SKILL.md"),
|
||||
relative_path=Path(name),
|
||||
category="public",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@ -132,6 +138,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
|
||||
@ -164,3 +171,106 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
assert captured_skills[-1] == {"skill1"}
|
||||
|
||||
|
||||
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
|
||||
|
||||
|
||||
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])
|
||||
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = None
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
create_agent_mock = MagicMock()
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
|
||||
def fail_storage(*args, **kwargs):
|
||||
raise RuntimeError("skill storage unavailable")
|
||||
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="skill storage unavailable"):
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
create_agent_mock.assert_not_called()
|
||||
|
||||
@ -86,6 +86,33 @@ def test_parse_license_field(tmp_path):
|
||||
assert skill.license == "MIT"
|
||||
|
||||
|
||||
def test_parse_missing_allowed_tools_returns_none(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test")
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is not None
|
||||
assert skill.allowed_tools is None
|
||||
|
||||
|
||||
def test_parse_allowed_tools_list(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, 'name: my-skill\ndescription: Test\nallowed-tools: ["bash", "read_file"]')
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is not None
|
||||
assert skill.allowed_tools == ["bash", "read_file"]
|
||||
|
||||
|
||||
def test_parse_empty_allowed_tools_list(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: []")
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is not None
|
||||
assert skill.allowed_tools == []
|
||||
|
||||
|
||||
def test_parse_invalid_allowed_tools_returns_none(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: bash")
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is None
|
||||
|
||||
|
||||
def test_parse_missing_name_returns_none(tmp_path):
|
||||
"""Skills missing a name field are rejected."""
|
||||
skill_file = _write_skill(tmp_path, "description: A test skill")
|
||||
|
||||
@ -30,13 +30,47 @@ class TestValidateSkillFrontmatter:
|
||||
def test_valid_with_all_allowed_fields(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\n---\n\nBody\n",
|
||||
"---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\nallowed-tools: [bash, read_file]\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is True
|
||||
assert msg == "Skill is valid!"
|
||||
assert name == "my-skill"
|
||||
|
||||
def test_allows_empty_allowed_tools(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nallowed-tools: []\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is True
|
||||
assert msg == "Skill is valid!"
|
||||
assert name == "my-skill"
|
||||
|
||||
def test_rejects_allowed_tools_string(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nallowed-tools: bash\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is False
|
||||
assert "allowed-tools" in msg
|
||||
assert str(tmp_path) not in msg
|
||||
assert "SKILL.md" in msg
|
||||
assert name is None
|
||||
|
||||
def test_rejects_allowed_tools_non_string_entry(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nallowed-tools: [bash, 1]\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is False
|
||||
assert "allowed-tools" in msg
|
||||
assert str(tmp_path) not in msg
|
||||
assert "SKILL.md" in msg
|
||||
assert name is None
|
||||
|
||||
def test_missing_skill_md(self, tmp_path):
|
||||
valid, msg, name = _validate_skill_frontmatter(tmp_path)
|
||||
assert valid is False
|
||||
|
||||
@ -17,11 +17,14 @@ import asyncio
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
# Module names that need to be mocked to break circular imports
|
||||
_MOCKED_MODULE_NAMES = [
|
||||
"deerflow.agents",
|
||||
@ -32,14 +35,15 @@ _MOCKED_MODULE_NAMES = [
|
||||
"deerflow.sandbox.middleware",
|
||||
"deerflow.sandbox.security",
|
||||
"deerflow.models",
|
||||
"deerflow.skills.storage",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_executor_classes():
|
||||
"""Set up mocked modules and import real executor classes.
|
||||
|
||||
This fixture runs once per session and yields the executor classes.
|
||||
This fixture runs once per test and yields the executor classes.
|
||||
It handles module cleanup to avoid affecting other test files.
|
||||
"""
|
||||
# Save original modules
|
||||
@ -53,6 +57,9 @@ def _setup_executor_classes():
|
||||
# Set up mocks
|
||||
for name in _MOCKED_MODULE_NAMES:
|
||||
sys.modules[name] = MagicMock()
|
||||
storage_module = ModuleType("deerflow.skills.storage")
|
||||
storage_module.get_or_new_skill_storage = lambda **kwargs: SimpleNamespace(load_skills=lambda *, enabled_only: [])
|
||||
sys.modules["deerflow.skills.storage"] = storage_module
|
||||
|
||||
# Import real classes inside fixture
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
@ -117,6 +124,26 @@ class MockAIMessage:
|
||||
return msg
|
||||
|
||||
|
||||
class NamedTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def _skill(name: str, allowed_tools: list[str] | None) -> Skill:
|
||||
skill_dir = Path(f"/tmp/{name}")
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"{name} skill",
|
||||
license=None,
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=Path(name),
|
||||
category="custom",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
async def async_iterator(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
@ -288,7 +315,7 @@ class TestAgentConstruction:
|
||||
captured["app_config"] = app_config
|
||||
return SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="demo-skill", skill_file=skill_file)])
|
||||
|
||||
monkeypatch.setattr("deerflow.skills.storage.get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
monkeypatch.setattr(sys.modules["deerflow.skills.storage"], "get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
@ -297,7 +324,8 @@ class TestAgentConstruction:
|
||||
thread_id="test-thread",
|
||||
)
|
||||
|
||||
messages = await executor._load_skill_messages()
|
||||
skills = await executor._load_skills()
|
||||
messages = await executor._load_skill_messages(skills)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(messages) == 1
|
||||
@ -487,6 +515,115 @@ class TestAsyncExecutionPath:
|
||||
assert "Task" in result.result
|
||||
|
||||
|
||||
class TestSkillAllowedTools:
|
||||
@pytest.mark.anyio
|
||||
async def test_skill_allowed_tools_union_filters_agent_tools(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("a", ["bash"]), _skill("b", ["read_file"])]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
create_agent_mock.assert_called_once()
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_all_missing_allowed_tools_preserves_legacy_allow_all(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("legacy-a", None), _skill("legacy-b", None)]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file", "web_search"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_mixed_missing_allowed_tools_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("legacy", None), _skill("restricted", ["bash"])]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_mixed_missing_allowed_tools_order_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("restricted", ["bash"]), _skill("legacy", None)]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_allowed_tools_contributes_no_tools(self, classes, base_config, mock_agent, msg, caplog):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("empty", []), _skill("reader", ["read_file"])]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock, caplog.at_level("INFO"):
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["read_file"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
assert "declared empty allowed-tools" in caplog.text
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skill_load_failure_fails_without_creating_agent(self, classes, base_config, mock_agent):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
executor = SubagentExecutor(config=base_config, tools=[NamedTool("bash")], thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
raise RuntimeError("skill storage unavailable")
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
result = await executor._aexecute("Task")
|
||||
|
||||
assert result.status == classes["SubagentStatus"].FAILED
|
||||
assert result.error == "skill storage unavailable"
|
||||
create_agent_mock.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sync Execution Path Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user