2026-01-07 16:24:01 +08:00

465 lines
16 KiB
Python
Executable File

"""Memory-related configuration dataclasses."""
from dataclasses import dataclass, field, replace
from typing import Any, Dict, List, Mapping
from entity.enums import AgentExecFlowStage
from entity.enum_options import enum_options_for, enum_options_from_values
from schema_registry import (
SchemaLookupError,
get_memory_store_schema,
iter_memory_store_schemas,
)
from entity.configs.base import (
BaseConfig,
ConfigError,
ConfigFieldSpec,
ChildKey,
ensure_list,
optional_dict,
optional_str,
require_mapping,
require_str,
extend_path,
)
@dataclass
class EmbeddingConfig(BaseConfig):
    """Embedding backend settings: provider, model, credentials and extras."""

    provider: str
    model: str
    api_key: str | None = None
    base_url: str | None = None
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "EmbeddingConfig":
        """Build an ``EmbeddingConfig`` from a raw configuration mapping.

        Args:
            data: Raw mapping holding the keys ``provider`` and ``model``
                (both read via ``require_str``) plus the optional keys
                ``api_key``, ``base_url`` and ``params``.
            path: Location of this block inside the overall configuration;
                forwarded to every helper and stored on the instance.

        Returns:
            The populated ``EmbeddingConfig``. ``params`` falls back to an
            empty dict when the helper yields a falsy value.

        Note:
            All validation is delegated to the imported ``require_*`` /
            ``optional_*`` helpers (presumably raising ``ConfigError`` on
            bad input -- confirm in ``entity.configs.base``).
        """
        section = require_mapping(data, path)
        # Keyword arguments are evaluated left to right, so the fields are
        # still read (and validated) in the order provider -> params.
        return cls(
            provider=require_str(section, "provider", path),
            model=require_str(section, "model", path),
            api_key=optional_str(section, "api_key", path),
            base_url=optional_str(section, "base_url", path),
            params=optional_dict(section, "params", path) or {},
            path=path,
        )

    # Per-key metadata (display name, type hint, default, description).
    # Left un-annotated on purpose so the dataclass machinery does not treat
    # it as an instance field.
    FIELD_SPECS = {
        "provider": ConfigFieldSpec(
            name="provider",
            display_name="Embedding Provider",
            type_hint="str",
            required=True,
            default="openai",
            description="Embedding provider",
        ),
        "model": ConfigFieldSpec(
            name="model",
            display_name="Embedding Model",
            type_hint="str",
            required=True,
            default="text-embedding-3-small",
            description="Embedding model name",
        ),
        "api_key": ConfigFieldSpec(
            name="api_key",
            display_name="API Key",
            type_hint="str",
            required=False,
            description="API key",
            default="${API_KEY}",
            advance=True,
        ),
        "base_url": ConfigFieldSpec(
            name="base_url",
            display_name="Base URL",
            type_hint="str",
            required=False,
            description="Custom Base URL",
            default="${BASE_URL}",
            advance=True,
        ),
        "params": ConfigFieldSpec(
            name="params",
            display_name="Custom Parameters",
            type_hint="dict[str, Any]",
            required=False,
            default={},
            description="Embedding parameters (temperature, etc.)",
            advance=True,
        ),
    }
@dataclass
class FileSourceConfig(BaseConfig):
    """One file or directory to be indexed by a file-backed memory store."""

    source_path: str
    file_types: List[str] | None = None
    recursive: bool = True
    encoding: str = "utf-8"

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "FileSourceConfig":
        """Build a ``FileSourceConfig`` from a raw configuration mapping.

        Args:
            data: Raw mapping. The required key is ``path`` (stored on the
                instance as ``source_path``); ``file_types``, ``recursive``
                and ``encoding`` are optional.
            path: Location of this block inside the overall configuration,
                used for error reporting and stored on the instance.

        Returns:
            The populated ``FileSourceConfig``.

        Raises:
            ConfigError: If a ``file_types`` entry is not a string or if
                ``recursive`` is not a boolean. The imported helpers may
                raise as well (not visible from this file).
        """
        section = require_mapping(data, path)
        location = require_str(section, "path", path)

        raw_types = section.get("file_types")
        if raw_types is None:
            suffixes: List[str] | None = None
        else:
            suffixes = []
            for position, entry in enumerate(ensure_list(raw_types)):
                if isinstance(entry, str):
                    suffixes.append(entry)
                    continue
                raise ConfigError(
                    "file_types entries must be strings",
                    extend_path(path, f"file_types[{position}]"),
                )

        descend = section.get("recursive", True)
        if not isinstance(descend, bool):
            raise ConfigError("recursive must be boolean", extend_path(path, "recursive"))

        # A missing or empty encoding falls back to UTF-8.
        charset = optional_str(section, "encoding", path) or "utf-8"

        return cls(
            source_path=location,
            file_types=suffixes,
            recursive=descend,
            encoding=charset,
            path=path,
        )

    # Per-key metadata; note the external key is "path" while the dataclass
    # attribute is ``source_path``.
    FIELD_SPECS = {
        "path": ConfigFieldSpec(
            name="path",
            display_name="File/Directory Path",
            type_hint="str",
            required=True,
            description="Path to file/directory to be indexed",
        ),
        "file_types": ConfigFieldSpec(
            name="file_types",
            display_name="File Type Filter",
            type_hint="list[str]",
            required=False,
            description="List of file type suffixes to limit (e.g. .md, .txt)",
        ),
        "recursive": ConfigFieldSpec(
            name="recursive",
            display_name="Recursive Subdirectories",
            type_hint="bool",
            required=False,
            default=True,
            description="Whether to include subdirectories recursively",
            advance=True,
        ),
        "encoding": ConfigFieldSpec(
            name="encoding",
            display_name="File Encoding",
            type_hint="str",
            required=False,
            default="utf-8",
            description="Encoding used to read files",
            advance=True,
        ),
    }
@dataclass
class SimpleMemoryConfig(BaseConfig):
    """Configuration for the simple memory store."""

    memory_path: str | None = None
    embedding: EmbeddingConfig | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "SimpleMemoryConfig":
        """Build a ``SimpleMemoryConfig`` from a raw configuration mapping.

        Args:
            data: Raw mapping with the optional keys ``memory_path`` and
                ``embedding``.
            path: Location of this block inside the overall configuration;
                stored on the instance and extended for the nested block.

        Returns:
            The populated ``SimpleMemoryConfig``. ``embedding`` stays
            ``None`` when the key is absent or explicitly ``None``.
        """
        section = require_mapping(data, path)
        storage_path = optional_str(section, "memory_path", path)

        raw_embedding = section.get("embedding")
        if raw_embedding is None:
            embedding = None
        else:
            embedding = EmbeddingConfig.from_dict(
                raw_embedding,
                path=extend_path(path, "embedding"),
            )

        return cls(memory_path=storage_path, embedding=embedding, path=path)

    # Per-key metadata; ``child`` names the nested config class.
    FIELD_SPECS = {
        "memory_path": ConfigFieldSpec(
            name="memory_path",
            display_name="Memory File Path",
            type_hint="str",
            required=False,
            description="Simple memory file path",
            advance=True,
        ),
        "embedding": ConfigFieldSpec(
            name="embedding",
            display_name="Embedding Configuration",
            type_hint="EmbeddingConfig",
            required=False,
            description="Optional embedding configuration",
            child=EmbeddingConfig,
        ),
    }
@dataclass
class FileMemoryConfig(BaseConfig):
    """Configuration for a memory store built from indexed files."""

    index_path: str | None = None
    file_sources: List[FileSourceConfig] = field(default_factory=list)
    embedding: EmbeddingConfig | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "FileMemoryConfig":
        """Build a ``FileMemoryConfig`` from a raw configuration mapping.

        Args:
            data: Raw mapping. ``file_sources`` must yield at least one
                entry; ``index_path`` (or, as a fallback, ``memory_path``)
                and ``embedding`` are optional.
            path: Location of this block inside the overall configuration;
                stored on the instance and extended for nested blocks.

        Returns:
            The populated ``FileMemoryConfig``.

        Raises:
            ConfigError: If ``file_sources`` is missing or empty. Nested
                parsers and the imported helpers may raise as well.
        """
        section = require_mapping(data, path)

        raw_sources = ensure_list(section.get("file_sources"))
        if not raw_sources:
            raise ConfigError(
                "file_sources must contain at least one entry",
                extend_path(path, "file_sources"),
            )
        parsed_sources: List[FileSourceConfig] = [
            FileSourceConfig.from_dict(entry, path=extend_path(path, f"file_sources[{position}]"))
            for position, entry in enumerate(raw_sources)
        ]

        # "memory_path" is consulted only when "index_path" yields None.
        index_location = optional_str(section, "index_path", path)
        if index_location is None:
            index_location = optional_str(section, "memory_path", path)

        raw_embedding = section.get("embedding")
        if raw_embedding is None:
            embedding = None
        else:
            embedding = EmbeddingConfig.from_dict(
                raw_embedding,
                path=extend_path(path, "embedding"),
            )

        return cls(
            index_path=index_location,
            file_sources=parsed_sources,
            embedding=embedding,
            path=path,
        )

    # Per-key metadata; ``child`` names the nested config class.
    FIELD_SPECS = {
        "index_path": ConfigFieldSpec(
            name="index_path",
            display_name="Index Path",
            type_hint="str",
            required=False,
            description="Vector index storage path",
            advance=True,
        ),
        "file_sources": ConfigFieldSpec(
            name="file_sources",
            display_name="File Source List",
            type_hint="list[FileSourceConfig]",
            required=True,
            description="List of file sources to ingest",
            child=FileSourceConfig,
        ),
        "embedding": ConfigFieldSpec(
            name="embedding",
            display_name="Embedding Configuration",
            type_hint="EmbeddingConfig",
            required=False,
            description="Embedding used for file memory",
            child=EmbeddingConfig,
        ),
    }
@dataclass
class BlackboardMemoryConfig(BaseConfig):
    """Configuration for the blackboard memory store."""

    memory_path: str | None = None
    max_items: int = 1000

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "BlackboardMemoryConfig":
        """Build a ``BlackboardMemoryConfig`` from a raw configuration mapping.

        Args:
            data: Raw mapping with the optional keys ``memory_path`` and
                ``max_items`` (defaults to 1000 when absent).
            path: Location of this block inside the overall configuration,
                used for error reporting and stored on the instance.

        Returns:
            The populated ``BlackboardMemoryConfig``.

        Raises:
            ConfigError: If ``max_items`` is not a strictly positive
                integer. Booleans are rejected even though ``bool`` is a
                subclass of ``int``.
        """
        mapping = require_mapping(data, path)
        memory_path = optional_str(mapping, "memory_path", path)

        max_items_value = mapping.get("max_items", 1000)
        # isinstance(True, int) is True in Python, so a stray `max_items: true`
        # would otherwise be accepted and stored as the value True.
        if (
            isinstance(max_items_value, bool)
            or not isinstance(max_items_value, int)
            or max_items_value <= 0
        ):
            raise ConfigError("max_items must be a positive integer", extend_path(path, "max_items"))

        return cls(memory_path=memory_path, max_items=max_items_value, path=path)

    # Per-key metadata. The "auto" default below is only the value declared
    # in the spec; from_dict itself leaves memory_path as returned by
    # optional_str when the key is absent.
    FIELD_SPECS = {
        "memory_path": ConfigFieldSpec(
            name="memory_path",
            display_name="Blackboard Path",
            type_hint="str",
            required=False,
            description="JSON path for blackboard memory writing. Pass 'auto' to auto-create in working directory, valid for this run only",
            default="auto",
            advance=True,
        ),
        "max_items": ConfigFieldSpec(
            name="max_items",
            display_name="Maximum Items",
            type_hint="int",
            required=False,
            default=1000,
            description="Maximum number of memory items to retain (trimmed by time)",
            advance=True,
        ),
    }
@dataclass
class MemoryStoreConfig(BaseConfig):
    """A named memory store whose payload schema depends on its ``type``."""

    name: str
    type: str
    config: BaseConfig | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "MemoryStoreConfig":
        """Build a ``MemoryStoreConfig`` from a raw configuration mapping.

        The ``type`` value is resolved through ``get_memory_store_schema``
        and the resulting schema's ``config_cls`` parses the ``config``
        block.

        Args:
            data: Raw mapping with the keys ``name``, ``type`` and
                ``config``.
            path: Location of this block inside the overall configuration,
                used for error reporting and stored on the instance.

        Returns:
            The populated ``MemoryStoreConfig``.

        Raises:
            ConfigError: If the store type is not registered, or if the
                ``config`` block is missing or ``None``.
        """
        section = require_mapping(data, path)
        store_name = require_str(section, "name", path)
        kind = require_str(section, "type", path)

        try:
            schema = get_memory_store_schema(kind)
        except SchemaLookupError as exc:
            raise ConfigError(
                f"unsupported memory store type '{kind}'",
                extend_path(path, "type"),
            ) from exc

        if section.get("config") is None:
            raise ConfigError("memory store requires config block", extend_path(path, "config"))

        payload = schema.config_cls.from_dict(
            section["config"],
            path=extend_path(path, "config"),
        )
        return cls(name=store_name, type=kind, config=payload, path=path)

    def require_payload(self) -> BaseConfig:
        """Return the store's payload config.

        Raises:
            ConfigError: If ``self.config`` is falsy (e.g. ``None``).
        """
        if self.config:
            return self.config
        raise ConfigError("memory store payload missing", extend_path(self.path, "config"))

    # Static per-key metadata; the "type" entry is enriched with the
    # registered store types by field_specs() below.
    FIELD_SPECS = {
        "name": ConfigFieldSpec(
            name="name",
            display_name="Store Name",
            type_hint="str",
            required=True,
            description="Unique name of the memory store",
        ),
        "type": ConfigFieldSpec(
            name="type",
            display_name="Store Type",
            type_hint="str",
            required=True,
            description="Memory store type",
        ),
        "config": ConfigFieldSpec(
            name="config",
            display_name="Store Configuration",
            type_hint="object",
            required=True,
            description="Schema required by the selected store type (simple/file/blackboard/etc.), following that type's required keys.",
        ),
    }

    @classmethod
    def child_routes(cls) -> Dict[ChildKey, type[BaseConfig]]:
        """Map each registered store type to the config class for ``config``."""
        routes: Dict[ChildKey, type[BaseConfig]] = {}
        for store_type, schema in iter_memory_store_schemas().items():
            routes[ChildKey(field="config", value=store_type)] = schema.config_cls
        return routes

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the field specs with ``type`` limited to registered stores.

        Starts from the parent's ``field_specs()`` result and, when it has
        a truthy ``type`` entry, replaces that entry with a copy carrying
        ``enum`` and ``enum_options`` built from the schema registry.
        """
        specs = super().field_specs()
        base_type_spec = specs.get("type")
        if not base_type_spec:
            return specs

        registry = iter_memory_store_schemas()
        store_types = list(registry)
        summaries = {store_type: schema.summary for store_type, schema in registry.items()}
        specs["type"] = replace(
            base_type_spec,
            enum=store_types,
            enum_options=enum_options_from_values(store_types, summaries, preserve_label_case=True),
        )
        return specs
@dataclass
class MemoryAttachmentConfig(BaseConfig):
    """Attachment of a named memory store, with retrieval and access options."""

    name: str
    retrieve_stage: List[AgentExecFlowStage] | None = None
    top_k: int = 3
    similarity_threshold: float = -1.0
    read: bool = True
    write: bool = True

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "MemoryAttachmentConfig":
        """Build a ``MemoryAttachmentConfig`` from a raw configuration mapping.

        Args:
            data: Raw mapping. ``name`` is required; ``retrieve_stage``,
                ``top_k`` (default 3), ``similarity_threshold`` (default
                -1.0), ``read`` and ``write`` (both default ``True``) are
                optional.
            path: Location of this block inside the overall configuration,
                used for error reporting and stored on the instance.

        Returns:
            The populated ``MemoryAttachmentConfig``; the threshold is
            always stored as a ``float``.

        Raises:
            ConfigError: If a ``retrieve_stage`` entry is not a valid
                ``AgentExecFlowStage`` value, ``top_k`` is not a strictly
                positive integer, ``similarity_threshold`` is not numeric,
                or ``read`` / ``write`` is not a boolean. Booleans are not
                accepted for the two numeric fields.
        """
        mapping = require_mapping(data, path)
        name = require_str(mapping, "name", path)

        stages_raw = mapping.get("retrieve_stage")
        stages: List[AgentExecFlowStage] | None = None
        if stages_raw is not None:
            parsed: List[AgentExecFlowStage] = []
            for idx, item in enumerate(ensure_list(stages_raw)):
                try:
                    parsed.append(AgentExecFlowStage(item))
                except ValueError as exc:
                    raise ConfigError(
                        f"retrieve_stage entries must be one of {[stage.value for stage in AgentExecFlowStage]}",
                        extend_path(path, f"retrieve_stage[{idx}]"),
                    ) from exc
            stages = parsed

        top_k_value = mapping.get("top_k", 3)
        # bool is a subclass of int, so `top_k: true` would otherwise pass
        # the integer check and be stored as True.
        if (
            isinstance(top_k_value, bool)
            or not isinstance(top_k_value, int)
            or top_k_value <= 0
        ):
            raise ConfigError("top_k must be a positive integer", extend_path(path, "top_k"))

        threshold_value = mapping.get("similarity_threshold", -1.0)
        # Same reasoning: without the bool guard, `true` would silently
        # become a threshold of 1.0.
        if isinstance(threshold_value, bool) or not isinstance(threshold_value, (int, float)):
            raise ConfigError("similarity_threshold must be numeric", extend_path(path, "similarity_threshold"))

        read_value = mapping.get("read", True)
        if not isinstance(read_value, bool):
            raise ConfigError("read must be boolean", extend_path(path, "read"))

        write_value = mapping.get("write", True)
        if not isinstance(write_value, bool):
            raise ConfigError("write must be boolean", extend_path(path, "write"))

        return cls(
            name=name,
            retrieve_stage=stages,
            top_k=top_k_value,
            similarity_threshold=float(threshold_value),
            read=read_value,
            write=write_value,
            path=path,
        )

    # Per-key metadata; the retrieve_stage choices are derived from the
    # AgentExecFlowStage enum when the class body is executed.
    FIELD_SPECS = {
        "name": ConfigFieldSpec(
            name="name",
            display_name="Memory Name",
            type_hint="str",
            required=True,
            description="Name of the referenced memory store",
        ),
        "retrieve_stage": ConfigFieldSpec(
            name="retrieve_stage",
            display_name="Retrieve Stage",
            type_hint="list[AgentExecFlowStage]",
            required=False,
            description="Execution stages when memory retrieval occurs. If not set, defaults to all stages. NOTE: this config is related to thinking, if the thinking module is not used, this config has only effect on `gen` stage.",
            enum=[stage.value for stage in AgentExecFlowStage],
            enum_options=enum_options_for(AgentExecFlowStage),
        ),
        "top_k": ConfigFieldSpec(
            name="top_k",
            display_name="Top K",
            type_hint="int",
            required=False,
            default=3,
            description="Number of items to retrieve",
            advance=True,
        ),
        "similarity_threshold": ConfigFieldSpec(
            name="similarity_threshold",
            display_name="Similarity Threshold",
            type_hint="float",
            required=False,
            default=-1.0,
            description="Similarity threshold (-1 means no similarity threshold filter)",
            advance=True,
        ),
        "read": ConfigFieldSpec(
            name="read",
            display_name="Allow Read",
            type_hint="bool",
            required=False,
            default=True,
            description="Whether to read this memory during execution",
            advance=True,
        ),
        "write": ConfigFieldSpec(
            name="write",
            display_name="Allow Write",
            type_hint="bool",
            required=False,
            default=True,
            description="Whether to write back to this memory after execution",
            advance=True,
        ),
    }