2026-01-07 16:24:01 +08:00

117 lines
3.4 KiB
Python
Executable File

"""Abstract base classes for agent providers."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from entity.configs import AgentConfig
from entity.messages import Message
from schema_registry import register_model_provider_schema
from entity.tool_spec import ToolSpec
from runtime.node.agent.providers.response import ModelResponse
from utils.token_tracker import TokenUsage
from utils.registry import Registry
class ModelProvider(ABC):
"""Abstract base class for all agent providers."""
def __init__(self, config: AgentConfig):
"""
Initialize the agent provider with configuration.
Args:
config: Agent configuration instance
"""
self.config = config
self.base_url = config.base_url
self.api_key = config.api_key
self.model_name = config.name if isinstance(config.name, str) else str(config.name)
self.provider = config.provider
self.params = config.params or {}
@abstractmethod
def create_client(self):
"""
Create and return the appropriate client for this provider.
Returns:
Client instance for making API calls
"""
pass
@abstractmethod
def call_model(
self,
client,
conversation: List[Message],
timeline: List[Any],
tool_specs: Optional[List[ToolSpec]] = None,
**kwargs,
) -> ModelResponse:
"""
Call the model with the given messages and parameters.
Args:
client: Provider-specific client instance
conversation: List of messages in the conversation
tool_specs: Tool specifications available for this call
**kwargs: Additional parameters for the model call
Returns:
ModelResponse containing content and potentially tool calls
"""
pass
@abstractmethod
def extract_token_usage(self, response: Any) -> TokenUsage:
"""
Extract token usage from the API response.
Args:
response: Raw API response from the model call
Returns:
TokenUsage instance with token counts
"""
pass
_provider_registry = Registry("agent_provider")
class ProviderRegistry:
"""Registry facade for agent providers."""
@classmethod
def register(
cls,
name: str,
provider_class: type,
*,
label: str | None = None,
summary: str | None = None,
) -> None:
metadata = {
"label": label,
"summary": summary,
}
# Drop None values so schema consumers don't need to filter.
metadata = {key: value for key, value in metadata.items() if value is not None}
_provider_registry.register(name, target=provider_class, metadata=metadata)
register_model_provider_schema(name, label=label, summary=summary)
@classmethod
def get_provider(cls, name: str) -> type | None:
try:
entry = _provider_registry.get(name)
except Exception:
return None
return entry.load()
@classmethod
def list_providers(cls) -> List[str]:
return list(_provider_registry.names())
@classmethod
def iter_metadata(cls) -> Dict[str, Dict[str, Any]]:
return {name: dict(entry.metadata or {}) for name, entry in _provider_registry.items()}