578 lines
22 KiB
Python
Executable File

"""Agent-specific configuration dataclasses."""
import math
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Iterable, List, Mapping, Sequence
try: # pragma: no cover - Python < 3.11 lacks BaseExceptionGroup
from builtins import BaseExceptionGroup as _BASE_EXCEPTION_GROUP_TYPE # type: ignore[attr-defined]
except ImportError: # pragma: no cover
_BASE_EXCEPTION_GROUP_TYPE = None # type: ignore[assignment]
from entity.enums import AgentInputMode
from schema_registry import iter_model_provider_schemas
from utils.strs import titleize
from entity.configs.base import (
BaseConfig,
ConfigError,
ConfigFieldSpec,
EnumOption,
optional_bool,
optional_dict,
optional_str,
require_mapping,
require_str,
extend_path,
)
from .memory import MemoryAttachmentConfig
from .skills import AgentSkillsConfig
from .thinking import ThinkingConfig
from entity.configs.node.tooling import ToolingConfig
# HTTP status codes retried by default (AgentRetryConfig.retry_on_status_codes).
DEFAULT_RETRYABLE_STATUS_CODES = [
    408,  # Request Timeout
    409,  # Conflict
    425,  # Too Early
    429,  # Too Many Requests
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
]

# Exception class names retried by default. AgentRetryConfig lowercases these
# and compares them with the lowercased names of every class in the raised
# exception's MRO, so subclasses of a listed type also match.
# NOTE(review): the transport names (RemoteProtocolError, ConnectError, ...)
# look like httpx exception classes -- confirm the intended client library.
DEFAULT_RETRYABLE_EXCEPTION_TYPES = [
    "RateLimitError", "APITimeoutError", "APIError",
    "APIConnectionError", "ServiceUnavailableError", "TimeoutError",
    "InternalServerError", "RemoteProtocolError", "TransportError",
    "ConnectError", "ConnectTimeout", "ReadError",
    "ReadTimeout",
]

# Lowercase fragments searched for in the lowercased exception message.
DEFAULT_RETRYABLE_MESSAGE_SUBSTRINGS = [
    "rate limit", "temporarily unavailable", "timeout",
    "server disconnected", "connection reset",
]
def _coerce_float(value: Any, *, field_path: str, minimum: float = 0.0) -> float:
    """Validate a numeric config value and return it as a ``float``.

    Args:
        value: Raw value taken from the config mapping.
        field_path: Config path reported in error messages.
        minimum: Smallest accepted value (inclusive).

    Returns:
        ``value`` converted to ``float``.

    Raises:
        ConfigError: if ``value`` is not an ``int``/``float`` (``bool`` is
            rejected, matching ``_coerce_positive_int``), is NaN, is too large
            to represent as a float, or is below ``minimum``.
    """
    # bool subclasses int; without this check ``True`` would silently become 1.0.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ConfigError("expected number", field_path)
    try:
        coerced = float(value)
    except OverflowError as exc:  # e.g. an int like 10**400
        raise ConfigError("number is too large", field_path) from exc
    # NaN compares False against everything, so it would slip past the minimum
    # check below and any later min/max comparison made by callers.
    if math.isnan(coerced):
        raise ConfigError("expected a number, got NaN", field_path)
    if coerced < minimum:
        raise ConfigError(f"value must be >= {minimum}", field_path)
    return coerced
def _coerce_positive_int(value: Any, *, field_path: str, minimum: int = 1) -> int:
    """Return ``value`` unchanged if it is a real ``int`` that is ``>= minimum``.

    Raises:
        ConfigError: if ``value`` is not an ``int`` (``bool`` counts as not an
            int here, even though it subclasses ``int``) or is below ``minimum``.
    """
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    if not is_plain_int:
        raise ConfigError("expected integer", field_path)
    if value < minimum:
        raise ConfigError(f"value must be >= {minimum}", field_path)
    return value
def _coerce_str_list(value: Any, *, field_path: str) -> List[str]:
    """Validate an optional list of strings and return whitespace-stripped copies.

    ``None`` yields an empty list. A bare ``str``/``bytes`` is rejected rather
    than treated as a sequence of characters. Entries that are empty after
    stripping are kept; filtering them out is left to the caller.

    Raises:
        ConfigError: if ``value`` is not a list-like sequence, or (reported at
            ``field_path[i]``) if any entry is not a string.
    """
    if value is None:
        return []
    if isinstance(value, (str, bytes)) or not isinstance(value, Sequence):
        raise ConfigError("expected list of strings", field_path)

    def _stripped_entries(items: Sequence[Any]) -> Iterable[str]:
        # Fails on the first non-string entry, reporting its index.
        for position, entry in enumerate(items):
            if not isinstance(entry, str):
                raise ConfigError("expected list of strings", f"{field_path}[{position}]")
            yield entry.strip()

    return list(_stripped_entries(value))
def _coerce_int_list(value: Any, *, field_path: str) -> List[int]:
    """Validate an optional list of integers and return it as a new ``list``.

    ``None`` yields an empty list. Strings/bytes are rejected even though they
    are sequences, and ``bool`` entries are rejected even though ``bool``
    subclasses ``int``.

    Raises:
        ConfigError: if ``value`` is not a list-like sequence, or (reported at
            ``field_path[i]``) if any entry is not a plain ``int``.
    """
    if value is None:
        return []
    if isinstance(value, (str, bytes)) or not isinstance(value, Sequence):
        raise ConfigError("expected list of integers", field_path)
    collected: List[int] = []
    for position, entry in enumerate(value):
        if isinstance(entry, int) and not isinstance(entry, bool):
            collected.append(entry)
            continue
        raise ConfigError("expected list of integers", f"{field_path}[{position}]")
    return collected
@dataclass
class AgentRetryConfig(BaseConfig):
    """Retry policy for model-provider calls.

    ``should_retry`` inspects a raised exception together with every exception
    reachable through ``__cause__``, ``__context__`` and, where available,
    ``BaseExceptionGroup.exceptions``. A match against
    ``non_retry_exception_types`` anywhere in that chain vetoes a retry.
    Otherwise a retry is allowed when any exception in the chain matches a
    retryable class name, a retryable HTTP status code, or a retryable message
    substring.

    Name and substring matching is done against lowercased exception data.
    ``from_dict`` lowercases configured values; values passed directly to the
    constructor are not normalized.
    """

    enabled: bool = True
    # Total number of attempts, counting the initial call; see ``is_active``.
    max_attempts: int = 5
    # Backoff bounds; ``from_dict`` enforces 0 <= min_wait_seconds <= max_wait_seconds.
    min_wait_seconds: float = 1.0
    max_wait_seconds: float = 6.0
    retry_on_status_codes: List[int] = field(default_factory=lambda: list(DEFAULT_RETRYABLE_STATUS_CODES))
    retry_on_exception_types: List[str] = field(
        default_factory=lambda: [name.lower() for name in DEFAULT_RETRYABLE_EXCEPTION_TYPES]
    )
    non_retry_exception_types: List[str] = field(default_factory=list)
    retry_on_error_substrings: List[str] = field(default_factory=lambda: list(DEFAULT_RETRYABLE_MESSAGE_SUBSTRINGS))

    FIELD_SPECS = {
        "enabled": ConfigFieldSpec(
            name="enabled",
            display_name="Enable Retry",
            type_hint="bool",
            required=False,
            default=True,
            description="Toggle automatic retry for provider calls",
        ),
        "max_attempts": ConfigFieldSpec(
            name="max_attempts",
            display_name="Max Attempts",
            type_hint="int",
            required=False,
            default=5,
            description="Maximum number of total attempts (initial call + retries)",
        ),
        "min_wait_seconds": ConfigFieldSpec(
            name="min_wait_seconds",
            display_name="Min Wait Seconds",
            type_hint="float",
            required=False,
            default=1.0,
            description="Minimum backoff wait before retry",
            advance=True,
        ),
        "max_wait_seconds": ConfigFieldSpec(
            name="max_wait_seconds",
            display_name="Max Wait Seconds",
            type_hint="float",
            required=False,
            default=6.0,
            description="Maximum backoff wait before retry",
            advance=True,
        ),
        "retry_on_status_codes": ConfigFieldSpec(
            name="retry_on_status_codes",
            display_name="Retryable Status Codes",
            type_hint="list[int]",
            required=False,
            description="HTTP status codes that should trigger a retry",
            advance=True,
        ),
        "retry_on_exception_types": ConfigFieldSpec(
            name="retry_on_exception_types",
            display_name="Retryable Exception Types",
            type_hint="list[str]",
            required=False,
            description="Exception class names (case-insensitive) that should trigger retries",
            advance=True,
        ),
        "non_retry_exception_types": ConfigFieldSpec(
            name="non_retry_exception_types",
            display_name="Non-Retryable Exception Types",
            type_hint="list[str]",
            required=False,
            description="Exception class names (case-insensitive) that should never retry",
            advance=True,
        ),
        "retry_on_error_substrings": ConfigFieldSpec(
            name="retry_on_error_substrings",
            display_name="Retryable Message Substrings",
            type_hint="list[str]",
            required=False,
            description="Substring matches within exception messages that enable retry",
            advance=True,
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "AgentRetryConfig":
        """Build a retry policy from a raw config mapping.

        A key that is omitted, or explicitly set to ``null``, falls back to the
        class default. A provided list replaces the default entirely, so an
        explicit empty list disables that matcher. Exception names and message
        substrings are lowercased, and entries that are blank after stripping
        are dropped.

        Raises:
            ConfigError: on a wrongly typed value, a number below its minimum,
                or ``max_wait_seconds < min_wait_seconds``.
        """
        mapping = require_mapping(data, path)

        def scalar(key: str, default: Any) -> Any:
            # Treat an explicit null like a missing key so it falls back to the
            # default, the same way the list fields below handle null.
            value = mapping.get(key)
            return default if value is None else value

        enabled = optional_bool(mapping, "enabled", path, default=True)
        if enabled is None:
            enabled = True

        max_attempts = _coerce_positive_int(
            scalar("max_attempts", 5),
            field_path=extend_path(path, "max_attempts"),
        )
        min_wait = _coerce_float(
            scalar("min_wait_seconds", 1.0),
            field_path=extend_path(path, "min_wait_seconds"),
            minimum=0.0,
        )
        max_wait = _coerce_float(
            scalar("max_wait_seconds", 6.0),
            field_path=extend_path(path, "max_wait_seconds"),
            minimum=0.0,
        )
        if max_wait < min_wait:
            raise ConfigError("max_wait_seconds must be >= min_wait_seconds", extend_path(path, "max_wait_seconds"))

        status_codes = mapping.get("retry_on_status_codes")
        if status_codes is None:
            retry_status_codes = list(DEFAULT_RETRYABLE_STATUS_CODES)
        else:
            retry_status_codes = _coerce_int_list(status_codes, field_path=extend_path(path, "retry_on_status_codes"))

        retry_types_raw = mapping.get("retry_on_exception_types")
        if retry_types_raw is None:
            retry_types = [name.lower() for name in DEFAULT_RETRYABLE_EXCEPTION_TYPES]
        else:
            retry_types = [
                value.lower()
                for value in _coerce_str_list(
                    retry_types_raw,
                    field_path=extend_path(path, "retry_on_exception_types"),
                )
                if value
            ]

        # There are no default veto types: only explicitly configured names block a retry.
        non_retry_types = [
            value.lower()
            for value in _coerce_str_list(
                mapping.get("non_retry_exception_types"),
                field_path=extend_path(path, "non_retry_exception_types"),
            )
            if value
        ]

        retry_substrings_raw = mapping.get("retry_on_error_substrings")
        if retry_substrings_raw is None:
            retry_substrings = list(DEFAULT_RETRYABLE_MESSAGE_SUBSTRINGS)
        else:
            retry_substrings = [
                value.lower()
                for value in _coerce_str_list(
                    retry_substrings_raw,
                    field_path=extend_path(path, "retry_on_error_substrings"),
                )
                if value
            ]

        return cls(
            enabled=enabled,
            max_attempts=max_attempts,
            min_wait_seconds=min_wait,
            max_wait_seconds=max_wait,
            retry_on_status_codes=retry_status_codes,
            retry_on_exception_types=retry_types,
            non_retry_exception_types=non_retry_types,
            retry_on_error_substrings=retry_substrings,
            path=path,
        )

    @property
    def is_active(self) -> bool:
        """True when retrying is enabled and more than one attempt is allowed."""
        return self.enabled and self.max_attempts > 1

    def should_retry(self, exc: BaseException) -> bool:
        """Return True if ``exc`` should trigger another attempt under this policy.

        The checks run in priority order over the whole exception chain:
        non-retryable class names (veto), then retryable class names, then
        retryable HTTP status codes, then retryable message substrings.
        """
        if not self.is_active:
            return False
        # Precompute (exception, lowercased class names, status code, lowercased message)
        # once per linked exception so each check below can scan the chain cheaply.
        chain: List[tuple[BaseException, set[str], int | None, str]] = []
        for error in self._iter_exception_chain(exc):
            chain.append(
                (
                    error,
                    self._exception_name_set(error),
                    self._extract_status_code(error),
                    str(error).lower(),
                )
            )
        # A veto anywhere in the chain wins over any retryable match.
        if self.non_retry_exception_types:
            for _, names, _, _ in chain:
                if any(name in names for name in self.non_retry_exception_types):
                    return False
        if self.retry_on_exception_types:
            for _, names, _, _ in chain:
                if any(name in names for name in self.retry_on_exception_types):
                    return True
        if self.retry_on_status_codes:
            for _, _, status_code, _ in chain:
                if status_code is not None and status_code in self.retry_on_status_codes:
                    return True
        if self.retry_on_error_substrings:
            for _, _, _, message in chain:
                if message and any(substr in message for substr in self.retry_on_error_substrings):
                    return True
        return False

    def _exception_name_set(self, exc: BaseException) -> set[str]:
        """Return the lowercased bare and module-qualified names of every class in ``exc``'s MRO.

        Walking the MRO lets a configured base-class name match its subclasses.
        """
        names: set[str] = set()
        for cls in exc.__class__.mro():
            names.add(cls.__name__.lower())
            names.add(f"{cls.__module__}.{cls.__name__}".lower())
        return names

    def _extract_status_code(self, exc: BaseException) -> int | None:
        """Return an integer status code from ``exc`` or ``exc.response``, if one is present.

        The lookup is best-effort. Attribute names are probed in order, and the
        first ``int`` value found wins.
        """
        for attr in ("status_code", "http_status", "status", "statusCode"):
            value = getattr(exc, attr, None)
            if isinstance(value, int):
                return value
        response = getattr(exc, "response", None)
        if response is not None:
            for attr in ("status_code", "status", "statusCode"):
                value = getattr(response, attr, None)
                if isinstance(value, int):
                    return value
        return None

    def _iter_exception_chain(self, exc: BaseException) -> Iterable[BaseException]:
        """Yield ``exc`` and every exception linked to it, each exactly once.

        Links followed: ``__cause__``, ``__context__`` and, when
        ``BaseExceptionGroup`` exists, the group's member ``exceptions``. The
        traversal is depth-first with an explicit stack. Exceptions are
        de-duplicated by identity, so cyclic context chains terminate.
        """
        seen: set[int] = set()
        stack: List[BaseException] = [exc]
        while stack:
            current = stack.pop()
            if id(current) in seen:
                continue
            seen.add(id(current))
            yield current
            linked: List[BaseException] = []
            cause = getattr(current, "__cause__", None)
            context = getattr(current, "__context__", None)
            if isinstance(cause, BaseException):
                linked.append(cause)
            if isinstance(context, BaseException):
                linked.append(context)
            if _BASE_EXCEPTION_GROUP_TYPE is not None and isinstance(current, _BASE_EXCEPTION_GROUP_TYPE):
                for exc_item in getattr(current, "exceptions", None) or ():
                    if isinstance(exc_item, BaseException):
                        linked.append(exc_item)
            stack.extend(linked)
@dataclass
class AgentConfig(BaseConfig):
    """Model/agent configuration.

    Holds the provider and model name, the optional endpoint, credentials and
    call parameters, plus optional tooling, thinking, memory, skills and retry
    sub-configs.
    """

    provider: str
    # ``from_dict`` fills this from ``optional_str``, so it may be None
    # (meaning "use the provider's built-in endpoint", per the field spec).
    base_url: str | None
    name: str
    role: str | None = None
    api_key: str | None = None
    params: Dict[str, Any] = field(default_factory=dict)
    retry: AgentRetryConfig | None = None
    input_mode: AgentInputMode = AgentInputMode.MESSAGES
    tooling: List[ToolingConfig] = field(default_factory=list)
    thinking: ThinkingConfig | None = None
    memories: List[MemoryAttachmentConfig] = field(default_factory=list)
    skills: AgentSkillsConfig | None = None
    # Runtime attributes (attached dynamically)
    token_tracker: Any | None = field(default=None, init=False, repr=False)
    node_id: str | None = field(default=None, init=False, repr=False)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "AgentConfig":
        """Parse and validate a model/agent config mapping.

        ``name`` must be a non-blank string and is stored stripped.
        ``input_mode`` defaults to messages and is matched case-insensitively.
        The nested sections (``tooling``, ``thinking``, ``memories``, ``retry``,
        ``skills``) are parsed only when present and not null.

        Raises:
            ConfigError: on a missing/blank ``name``, an unknown ``input_mode``,
                a non-list ``tooling``/``memories``, or errors raised by the
                field helpers and nested config parsers.
        """
        mapping = require_mapping(data, path)
        provider = require_str(mapping, "provider", path)
        base_url = optional_str(mapping, "base_url", path)

        name_value = mapping.get("name")
        if not isinstance(name_value, str) or not name_value.strip():
            raise ConfigError("model.name must be a non-empty string", extend_path(path, "name"))
        model_name = name_value.strip()

        role = optional_str(mapping, "role", path)
        api_key = optional_str(mapping, "api_key", path)
        params = optional_dict(mapping, "params", path) or {}

        input_mode = AgentInputMode.MESSAGES
        raw_input_mode = optional_str(mapping, "input_mode", path)
        if raw_input_mode:
            try:
                input_mode = AgentInputMode(raw_input_mode.strip().lower())
            except ValueError as exc:
                raise ConfigError(
                    "model.input_mode must be 'prompt' or 'messages'",
                    extend_path(path, "input_mode"),
                ) from exc

        # Parsed in this order so that validation errors surface in the same sequence as before.
        tooling_cfg = cls._parse_child_list(mapping, "tooling", ToolingConfig, path)
        thinking_cfg = cls._parse_child(mapping, "thinking", ThinkingConfig, path)
        memories_cfg = cls._parse_child_list(mapping, "memories", MemoryAttachmentConfig, path)
        retry_cfg = cls._parse_child(mapping, "retry", AgentRetryConfig, path)
        skills_cfg = cls._parse_child(mapping, "skills", AgentSkillsConfig, path)

        return cls(
            provider=provider,
            base_url=base_url,
            name=model_name,
            role=role,
            api_key=api_key,
            params=params,
            tooling=tooling_cfg,
            thinking=thinking_cfg,
            memories=memories_cfg,
            skills=skills_cfg,
            retry=retry_cfg,
            input_mode=input_mode,
            path=path,
        )

    @staticmethod
    def _parse_child(mapping: Mapping[str, Any], key: str, child_cls: Any, path: str) -> Any:
        """Parse ``mapping[key]`` with ``child_cls.from_dict``; a missing or null key yields None."""
        raw_value = mapping.get(key)
        if raw_value is None:
            return None
        return child_cls.from_dict(raw_value, path=extend_path(path, key))

    @staticmethod
    def _parse_child_list(mapping: Mapping[str, Any], key: str, child_cls: Any, path: str) -> List[Any]:
        """Parse ``mapping[key]`` as a list of ``child_cls`` configs; a missing or null key yields []."""
        raw_items = mapping.get(key)
        if raw_items is None:
            return []
        if not isinstance(raw_items, list):
            raise ConfigError(f"{key} must be a list", extend_path(path, key))
        return [
            child_cls.from_dict(item, path=extend_path(path, f"{key}[{idx}]"))
            for idx, item in enumerate(raw_items)
        ]

    FIELD_SPECS = {
        "name": ConfigFieldSpec(
            name="name",
            display_name="Model Name",
            type_hint="str",
            required=True,
            description="Specific model name e.g. gpt-4o",
        ),
        "role": ConfigFieldSpec(
            name="role",
            display_name="System Prompt",
            type_hint="text",
            required=False,
            description="Model system prompt",
        ),
        "provider": ConfigFieldSpec(
            name="provider",
            display_name="Model Provider",
            type_hint="str",
            required=True,
            description="Name of a registered provider (openai, gemini, etc.) that selects the underlying client adapter.",
            default="openai",
        ),
        "base_url": ConfigFieldSpec(
            name="base_url",
            display_name="Base URL",
            type_hint="str",
            required=False,
            description="Override the provider's default endpoint; leave empty to use the built-in base URL.",
            advance=True,
            default="${BASE_URL}",
        ),
        "api_key": ConfigFieldSpec(
            name="api_key",
            display_name="API Key",
            type_hint="str",
            required=False,
            description="Credential consumed by the provider client; reference an env var such as ${API_KEY} that matches the selected provider.",
            advance=True,
            default="${API_KEY}",
        ),
        "params": ConfigFieldSpec(
            name="params",
            display_name="Call Parameters",
            type_hint="dict[str, Any]",
            required=False,
            default={},
            description="Call parameters (temperature, top_p, etc.)",
            advance=True,
        ),
        # "input_mode": ConfigFieldSpec( # currently, many features depend on messages mode, so hide this and force messages
        # name="input_mode",
        # display_name="Input Mode",
        # type_hint="enum:AgentInputMode",
        # required=False,
        # default=AgentInputMode.MESSAGES.value,
        # description="Model input mode: messages (default) or prompt",
        # enum=[item.value for item in AgentInputMode],
        # advance=True,
        # enum_options=enum_options_for(AgentInputMode),
        # ),
        "tooling": ConfigFieldSpec(
            name="tooling",
            display_name="Tool Configuration",
            type_hint="list[ToolingConfig]",
            required=False,
            description="Bound tool configuration list",
            child=ToolingConfig,
            advance=True,
        ),
        "thinking": ConfigFieldSpec(
            name="thinking",
            display_name="Thinking Configuration",
            type_hint="ThinkingConfig",
            required=False,
            description="Thinking process configuration",
            child=ThinkingConfig,
            advance=True,
        ),
        "memories": ConfigFieldSpec(
            name="memories",
            display_name="Memory Attachments",
            type_hint="list[MemoryAttachmentConfig]",
            required=False,
            description="Associated memory references",
            child=MemoryAttachmentConfig,
            advance=True,
        ),
        "skills": ConfigFieldSpec(
            name="skills",
            display_name="Agent Skills",
            type_hint="AgentSkillsConfig",
            required=False,
            description="Agent Skills allowlist and built-in skill activation/file-read tools.",
            child=AgentSkillsConfig,
            advance=True,
        ),
        "retry": ConfigFieldSpec(
            name="retry",
            display_name="Retry Policy",
            type_hint="AgentRetryConfig",
            required=False,
            description="Automatic retry policy for this model",
            child=AgentRetryConfig,
            advance=True,
        ),
    }

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the field specs with the provider spec turned into an enum of registered providers."""
        # Copy before assigning: BaseConfig.field_specs() may hand back a shared
        # mapping such as the class-level FIELD_SPECS (NOTE(review): confirm),
        # and writing into it would bake one registry snapshot in permanently.
        specs = dict(super().field_specs())
        provider_spec = specs.get("provider")
        if provider_spec:
            specs["provider"] = cls._apply_provider_enum(provider_spec)
        return specs

    @staticmethod
    def _apply_provider_enum(provider_spec: ConfigFieldSpec) -> ConfigFieldSpec:
        """Return a copy of ``provider_spec`` restricted to the currently registered providers.

        If no providers are registered, the spec is returned unchanged. The
        default is kept when it names a registered provider; otherwise a
        preferred provider is chosen.
        """
        provider_names, metadata = AgentConfig._provider_registry_snapshot()
        if not provider_names:
            return provider_spec
        enum_options: List[EnumOption] = []
        for name in provider_names:
            meta = metadata.get(name) or {}
            label = meta.get("label") or titleize(name)
            enum_options.append(
                EnumOption(
                    value=name,
                    label=label,
                    description=meta.get("summary"),
                )
            )
        default_value = provider_spec.default
        if not default_value or default_value not in provider_names:
            default_value = AgentConfig._preferred_provider_default(provider_names)
        return replace(
            provider_spec,
            enum=provider_names,
            enum_options=enum_options,
            default=default_value,
        )

    @staticmethod
    def _preferred_provider_default(provider_names: List[str]) -> str:
        """Pick ``"openai"`` when registered, else the first registered provider (list must be non-empty)."""
        if "openai" in provider_names:
            return "openai"
        return provider_names[0]

    @staticmethod
    def _provider_registry_snapshot() -> tuple[List[str], Dict[str, Dict[str, Any]]]:
        """Return registered provider names and per-provider display metadata.

        For each provider, the metadata dict starts with ``label`` and
        ``summary`` from the schema; keys in ``spec.metadata`` are merged last
        and therefore override them.
        """
        specs = iter_model_provider_schemas()
        names = list(specs.keys())
        metadata: Dict[str, Dict[str, Any]] = {}
        for name, spec in specs.items():
            metadata[name] = {
                "label": spec.label,
                "summary": spec.summary,
                **(spec.metadata or {}),
            }
        return names, metadata