mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 19:28:09 +00:00
117 lines
3.4 KiB
Python
Executable File
117 lines
3.4 KiB
Python
Executable File
"""Abstract base classes for agent providers."""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from entity.configs import AgentConfig
|
|
from entity.messages import Message
|
|
from schema_registry import register_model_provider_schema
|
|
from entity.tool_spec import ToolSpec
|
|
from runtime.node.agent.providers.response import ModelResponse
|
|
from utils.token_tracker import TokenUsage
|
|
from utils.registry import Registry
|
|
|
|
|
|
class ModelProvider(ABC):
|
|
"""Abstract base class for all agent providers."""
|
|
|
|
def __init__(self, config: AgentConfig):
|
|
"""
|
|
Initialize the agent provider with configuration.
|
|
|
|
Args:
|
|
config: Agent configuration instance
|
|
"""
|
|
self.config = config
|
|
self.base_url = config.base_url
|
|
self.api_key = config.api_key
|
|
self.model_name = config.name if isinstance(config.name, str) else str(config.name)
|
|
self.provider = config.provider
|
|
self.params = config.params or {}
|
|
|
|
@abstractmethod
|
|
def create_client(self):
|
|
"""
|
|
Create and return the appropriate client for this provider.
|
|
|
|
Returns:
|
|
Client instance for making API calls
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def call_model(
|
|
self,
|
|
client,
|
|
conversation: List[Message],
|
|
timeline: List[Any],
|
|
tool_specs: Optional[List[ToolSpec]] = None,
|
|
**kwargs,
|
|
) -> ModelResponse:
|
|
"""
|
|
Call the model with the given messages and parameters.
|
|
|
|
Args:
|
|
client: Provider-specific client instance
|
|
conversation: List of messages in the conversation
|
|
tool_specs: Tool specifications available for this call
|
|
**kwargs: Additional parameters for the model call
|
|
|
|
Returns:
|
|
ModelResponse containing content and potentially tool calls
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def extract_token_usage(self, response: Any) -> TokenUsage:
|
|
"""
|
|
Extract token usage from the API response.
|
|
|
|
Args:
|
|
response: Raw API response from the model call
|
|
|
|
Returns:
|
|
TokenUsage instance with token counts
|
|
"""
|
|
pass
|
|
|
|
|
|
_provider_registry = Registry("agent_provider")
|
|
|
|
|
|
class ProviderRegistry:
|
|
"""Registry facade for agent providers."""
|
|
|
|
@classmethod
|
|
def register(
|
|
cls,
|
|
name: str,
|
|
provider_class: type,
|
|
*,
|
|
label: str | None = None,
|
|
summary: str | None = None,
|
|
) -> None:
|
|
metadata = {
|
|
"label": label,
|
|
"summary": summary,
|
|
}
|
|
# Drop None values so schema consumers don't need to filter.
|
|
metadata = {key: value for key, value in metadata.items() if value is not None}
|
|
_provider_registry.register(name, target=provider_class, metadata=metadata)
|
|
register_model_provider_schema(name, label=label, summary=summary)
|
|
|
|
@classmethod
|
|
def get_provider(cls, name: str) -> type | None:
|
|
try:
|
|
entry = _provider_registry.get(name)
|
|
except Exception:
|
|
return None
|
|
return entry.load()
|
|
|
|
@classmethod
|
|
def list_providers(cls) -> List[str]:
|
|
return list(_provider_registry.names())
|
|
|
|
@classmethod
|
|
def iter_metadata(cls) -> Dict[str, Dict[str, Any]]:
|
|
return {name: dict(entry.metadata or {}) for name, entry in _provider_registry.items()}
|