ChatDev/entity/configs/node/subgraph.py
2026-01-07 16:24:01 +08:00

267 lines
9.5 KiB
Python
Executable File

"""Subgraph node configuration and registry helpers."""
from dataclasses import dataclass, replace
from typing import Any, Dict, Mapping
from entity.enums import LogLevel
from entity.enum_options import enum_options_for, enum_options_from_values
from entity.configs.base import (
BaseConfig,
ConfigError,
ConfigFieldSpec,
ChildKey,
require_mapping,
require_str,
extend_path,
)
from entity.configs.edge.edge import EdgeConfig
from entity.configs.node.memory import MemoryStoreConfig
from utils.registry import Registry, RegistryError
# Module-level registry of subgraph source types. Entries are added through
# register_subgraph_source(), which stores a config class as the entry target
# together with optional {"summary": ...} metadata.
subgraph_source_registry = Registry("subgraph_source")
def register_subgraph_source(
    name: str,
    *,
    config_cls: type[BaseConfig],
    description: str | None = None,
) -> None:
    """Register a subgraph source configuration class.

    Args:
        name: Key under which the source type is registered.
        config_cls: Configuration class stored as the registry target.
        description: Optional summary. When non-empty it is stored in the
            entry metadata under the ``"summary"`` key; otherwise the entry
            is registered with ``metadata=None``.
    """
    if description:
        entry_metadata = {"summary": description}
    else:
        entry_metadata = None
    subgraph_source_registry.register(
        name,
        target=config_cls,
        metadata=entry_metadata,
    )
def get_subgraph_source_config(name: str) -> type[BaseConfig]:
    """Look up ``name`` in the subgraph source registry and load its target.

    Args:
        name: Registered subgraph source name.

    Returns:
        The loaded configuration class.

    Raises:
        RegistryError: If the loaded target is not a ``BaseConfig`` subclass.
            The registry lookup itself presumably raises the same error for
            an unknown name (callers in this module catch ``RegistryError``)
            -- confirm against ``utils.registry``.
    """
    loaded = subgraph_source_registry.get(name).load()
    if isinstance(loaded, type) and issubclass(loaded, BaseConfig):
        return loaded
    raise RegistryError(f"Entry '{name}' is not a BaseConfig subclass")
def iter_subgraph_source_registrations() -> Dict[str, type[BaseConfig]]:
    """Return a dict mapping each registered source name to its loaded target.

    Every registry entry is loaded via ``entry.load()``; unlike
    ``get_subgraph_source_config`` no subclass check is applied here.
    """
    loaded_by_name = {}
    for source_name, entry in subgraph_source_registry.items():
        loaded_by_name[source_name] = entry.load()
    return loaded_by_name
def iter_subgraph_source_metadata() -> Dict[str, Dict[str, Any]]:
    """Return a dict mapping each registered source name to its metadata.

    Each value is a fresh ``dict`` copy of the entry's metadata; entries
    whose metadata is missing or empty map to an empty dict.
    """
    metadata_by_name = {}
    for source_name, entry in subgraph_source_registry.items():
        metadata_by_name[source_name] = dict(entry.metadata or {})
    return metadata_by_name
@dataclass
class SubgraphFileConfig(BaseConfig):
    """Subgraph source that references a graph through a file path."""

    # Value read from the ``path`` key of the source mapping.
    file_path: str

    # Declarative description of the keys this config accepts.
    FIELD_SPECS = {
        "path": ConfigFieldSpec(
            name="path",
            display_name="Subgraph File Path",
            type_hint="str",
            required=True,
            description="Subgraph file path (relative to yaml_instance/ or absolute path)",
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "SubgraphFileConfig":
        """Build the config from a mapping that carries a string ``path`` key.

        Args:
            data: Raw payload; passed through ``require_mapping``.
            path: Location of this block within the overall configuration,
                forwarded to the validation helpers and stored on the
                instance.

        Returns:
            A ``SubgraphFileConfig`` whose ``file_path`` is the ``path`` value.
        """
        source = require_mapping(data, path)
        return cls(file_path=require_str(source, "path", path), path=path)
@dataclass
class SubgraphInlineConfig(BaseConfig):
    """Subgraph source whose graph definition is embedded inline as a mapping."""

    # Shallow copy of the raw mapping handed to ``from_dict``.
    graph: Dict[str, Any]

    # Declarative description of the keys an inline subgraph may contain.
    # The ``nodes`` spec gets its ``child`` attached lazily in ``field_specs``.
    FIELD_SPECS = {
        "id": ConfigFieldSpec(
            name="id",
            display_name="Subgraph ID",
            type_hint="str",
            required=True,
            description="Subgraph identifier",
        ),
        "description": ConfigFieldSpec(
            name="description",
            display_name="Subgraph Description",
            type_hint="str",
            required=False,
            description="Describe the subgraph's responsibility, trigger conditions, and success criteria so reviewers know when to call it.",
        ),
        "log_level": ConfigFieldSpec(
            name="log_level",
            display_name="Log Level",
            type_hint="enum:LogLevel",
            required=False,
            default=LogLevel.INFO.value,
            enum=[lvl.value for lvl in LogLevel],
            description="Subgraph runtime log level",
            enum_options=enum_options_for(LogLevel),
        ),
        "is_majority_voting": ConfigFieldSpec(
            name="is_majority_voting",
            display_name="Majority Voting",
            type_hint="bool",
            required=False,
            default=False,
            description="Whether to perform majority voting on node results",
        ),
        "nodes": ConfigFieldSpec(
            name="nodes",
            display_name="Node List",
            type_hint="list[Node]",
            required=True,
            description="Subgraph node list, must contain at least one node",
        ),
        "edges": ConfigFieldSpec(
            name="edges",
            display_name="Edge List",
            type_hint="list[EdgeConfig]",
            required=True,
            description="Subgraph edge list",
            child=EdgeConfig,
        ),
        "memory": ConfigFieldSpec(
            name="memory",
            display_name="Memory Stores",
            type_hint="list[MemoryStoreConfig]",
            required=False,
            description="Provide a list of memory stores if this subgraph needs dedicated stores; leave empty to inherit parent graph stores.",
            child=MemoryStoreConfig,
        ),
        "vars": ConfigFieldSpec(
            name="vars",
            display_name="Variables",
            type_hint="dict[str, Any]",
            required=False,
            default={},
            description="Variables passed to subgraph nodes",
        ),
        "organization": ConfigFieldSpec(
            name="organization",
            display_name="Organization",
            type_hint="str",
            required=False,
            description="Subgraph organization/team identifier",
        ),
        "initial_instruction": ConfigFieldSpec(
            name="initial_instruction",
            display_name="Initial Instruction",
            type_hint="str",
            required=False,
            description="Subgraph level initial instruction",
        ),
        "start": ConfigFieldSpec(
            name="start",
            display_name="Start Node",
            type_hint="str | list[str]",
            required=False,
            description="Start node ID list (entry list executed at subgraph start; not recommended to edit manually)",
        ),
        "end": ConfigFieldSpec(
            name="end",
            display_name="End Node",
            type_hint="str | list[str]",
            required=False,
            description="End node ID list (used to collect final subgraph output, not part of execution logic). This is an ordered list: earlier nodes are checked first; the first with output becomes the subgraph output, otherwise continue down the list.",
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "SubgraphInlineConfig":
        """Wrap the raw inline graph mapping without interpreting its keys.

        Args:
            data: Raw payload; passed through ``require_mapping``.
            path: Location of this block within the overall configuration.

        Returns:
            A ``SubgraphInlineConfig`` holding a shallow ``dict`` copy of the
            mapping in ``graph``.
        """
        raw_graph = require_mapping(data, path)
        return cls(graph=dict(raw_graph), path=path)

    def validate(self) -> None:
        """Check that the inline graph declares both ``nodes`` and ``edges``.

        Only key presence is verified; the values are not inspected here.

        Raises:
            ConfigError: If ``nodes`` or ``edges`` is absent from ``graph``
                (``nodes`` is checked first).
        """
        required_keys = (
            ("nodes", "subgraph config must define nodes"),
            ("edges", "subgraph config must define edges"),
        )
        for key, message in required_keys:
            if key not in self.graph:
                raise ConfigError(message, extend_path(self.path, key))

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the inherited field specs with ``Node`` set as the ``nodes`` child."""
        specs = super().field_specs()
        nodes_spec = specs.get("nodes")
        if not nodes_spec:
            return specs

        # Imported at call time rather than at module level -- presumably to
        # avoid a circular import with the node module; confirm.
        from entity.configs.node.node import Node

        specs["nodes"] = replace(nodes_spec, child=Node)
        return specs
@dataclass
class SubgraphConfig(BaseConfig):
    """Subgraph node configuration: a source ``type`` plus its typed ``config``."""

    # Name of the registered subgraph source that interprets ``config``.
    type: str
    # Parsed payload, an instance of the class registered for ``type``.
    config: BaseConfig | None = None

    # Declarative description of the keys this config accepts. The ``type``
    # spec is enriched with the registered source names in ``field_specs``.
    FIELD_SPECS = {
        "type": ConfigFieldSpec(
            name="type",
            display_name="Subgraph Source Type",
            type_hint="str",
            required=True,
            description="Registered subgraph source such as 'config' or 'file' (see subgraph_source_registry).",
        ),
        "config": ConfigFieldSpec(
            name="config",
            display_name="Subgraph Configuration",
            type_hint="object",
            required=True,
            description="Payload interpreted by the chosen type—for example inline graph schema for 'config' or file path payload for 'file'.",
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "SubgraphConfig":
        """Parse a subgraph block and delegate ``config`` to the registered class.

        Args:
            data: Raw payload; must be a mapping with a string ``type`` and a
                non-null ``config`` entry.
            path: Location of this block within the overall configuration.

        Returns:
            A ``SubgraphConfig`` whose ``config`` was built by the class
            registered for ``type``.

        Raises:
            ConfigError: If a non-empty ``vars`` entry is present, if
                ``config`` is missing or ``None``, or if ``type`` cannot be
                resolved through the subgraph source registry.
        """
        mapping = require_mapping(data, path)
        source_type = require_str(mapping, "type", path)

        # A missing or empty ``vars`` is tolerated; only a truthy value is rejected.
        if mapping.get("vars"):
            raise ConfigError(
                "vars is only allowed at root level (DesignConfig.vars)",
                extend_path(path, "vars"),
            )

        payload = mapping.get("config")
        if payload is None:
            raise ConfigError(
                "subgraph configuration requires 'config' block",
                extend_path(path, "config"),
            )

        try:
            config_cls = get_subgraph_source_config(source_type)
        except RegistryError as exc:
            known_types = list(iter_subgraph_source_registrations().keys())
            raise ConfigError(
                f"subgraph.type must be one of {known_types}",
                extend_path(path, "type"),
            ) from exc

        nested = config_cls.from_dict(payload, path=extend_path(path, "config"))
        return cls(type=source_type, config=nested, path=path)

    def validate(self) -> None:
        """Require a nested config and run its own ``validate`` when it has one.

        Raises:
            ConfigError: If ``config`` is falsy.
        """
        nested = self.config
        if not nested:
            raise ConfigError("subgraph config missing", extend_path(self.path, "config"))
        if hasattr(nested, "validate"):
            nested.validate()

    @classmethod
    def child_routes(cls) -> Dict[ChildKey, type[BaseConfig]]:
        """Map each registered source name to the class parsing its ``config``.

        Keys are ``ChildKey(field="config", value=<source name>)``.
        """
        routes = {}
        for source_name, config_cls in iter_subgraph_source_registrations().items():
            routes[ChildKey(field="config", value=source_name)] = config_cls
        return routes

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the inherited specs with ``type`` listing the registered sources.

        The ``type`` spec receives the registered names as ``enum`` and
        matching ``enum_options`` whose descriptions come from each entry's
        ``"summary"`` metadata (``None`` when absent).
        """
        specs = super().field_specs()
        type_spec = specs.get("type")
        if not type_spec:
            return specs

        registrations = iter_subgraph_source_registrations()
        metadata = iter_subgraph_source_metadata()
        names = list(registrations)
        descriptions = {}
        for source_name in names:
            descriptions[source_name] = (metadata.get(source_name) or {}).get("summary")

        options = enum_options_from_values(names, descriptions, preserve_label_case=True)
        specs["type"] = replace(type_spec, enum=names, enum_options=options)
        return specs