mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
444 lines
15 KiB
Python
Executable File
444 lines
15 KiB
Python
Executable File
"""Shared dynamic configuration classes for both node and edge level execution.
|
|
|
|
This module contains the base classes used by both node-level and edge-level
|
|
dynamic execution configurations to avoid circular imports.
|
|
"""
|
|
|
|
from dataclasses import dataclass, fields, replace
|
|
from typing import Any, ClassVar, Dict, Mapping, Optional, Type, TypeVar
|
|
|
|
from entity.configs.base import (
|
|
BaseConfig,
|
|
ChildKey,
|
|
ConfigError,
|
|
ConfigFieldSpec,
|
|
extend_path,
|
|
optional_bool,
|
|
optional_str,
|
|
require_mapping,
|
|
require_str,
|
|
)
|
|
from entity.enum_options import enum_options_from_values
|
|
|
|
|
|
def _serialize_config(config: BaseConfig) -> Dict[str, Any]:
    """Return the dataclass fields of ``config`` as a plain dict.

    The ``path`` field is left out; every other field is copied by name in
    the order reported by ``dataclasses.fields``.
    """
    return {
        spec.name: getattr(config, spec.name)
        for spec in fields(config)
        if spec.name != "path"
    }
|
|
|
|
|
|
class SplitTypeConfig(BaseConfig):
    """Common base for the per-strategy split configuration classes."""

    def display_label(self) -> str:
        """Return a short label for this config (the class name by default)."""
        return type(self).__name__

    def to_external_value(self) -> Any:
        """Return this config's fields as a dict, without the ``path`` field."""
        return _serialize_config(self)
|
|
|
|
|
|
@dataclass
class MessageSplitConfig(SplitTypeConfig):
    """Split strategy in which each input message becomes one execution unit.

    The strategy takes no options, so ``FIELD_SPECS`` is empty.
    """

    FIELD_SPECS: ClassVar[Dict[str, ConfigFieldSpec]] = {}

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None, *, path: str) -> "MessageSplitConfig":
        """Build the config for ``path``.

        ``data`` is accepted so the signature matches the other split
        configs, but its contents are not read.
        """
        return cls(path=path)

    def display_label(self) -> str:
        """Return the fixed label ``"message"``."""
        return "message"
|
|
|
|
|
|
_NO_MATCH_DESCRIPTIONS = {
|
|
"pass": "Leave the content unchanged when no match is found.",
|
|
"empty": "Return empty content when no match is found.",
|
|
}
|
|
|
|
|
|
@dataclass
class RegexSplitConfig(SplitTypeConfig):
    """Configuration for regex-based splitting.

    Split content by regex pattern matches. Each match becomes one execution unit.

    Attributes:
        pattern: Python regular expression used to split content.
        group: Capture group name or index. Defaults to the entire match (group 0).
        case_sensitive: Whether the regex should be case sensitive.
        multiline: Enable multiline mode (re.MULTILINE).
        dotall: Enable dotall mode (re.DOTALL).
        on_no_match: Behavior when no match is found.
    """

    pattern: str = ""
    group: str | int | None = None
    case_sensitive: bool = True
    multiline: bool = False
    dotall: bool = False
    on_no_match: str = "pass"

    FIELD_SPECS = {
        "pattern": ConfigFieldSpec(
            name="pattern",
            display_name="Regex Pattern",
            type_hint="str",
            required=True,
            description="Python regular expression used to split content.",
        ),
        "group": ConfigFieldSpec(
            name="group",
            display_name="Capture Group",
            type_hint="str",
            required=False,
            description="Capture group name or index. Defaults to the entire match (group 0).",
        ),
        "case_sensitive": ConfigFieldSpec(
            name="case_sensitive",
            display_name="Case Sensitive",
            type_hint="bool",
            required=False,
            default=True,
            description="Whether the regex should be case sensitive.",
        ),
        "multiline": ConfigFieldSpec(
            name="multiline",
            display_name="Multiline Flag",
            type_hint="bool",
            required=False,
            default=False,
            description="Enable multiline mode (re.MULTILINE).",
            advance=True,
        ),
        "dotall": ConfigFieldSpec(
            name="dotall",
            display_name="Dotall Flag",
            type_hint="bool",
            required=False,
            default=False,
            description="Enable dotall mode (re.DOTALL).",
            advance=True,
        ),
        "on_no_match": ConfigFieldSpec(
            name="on_no_match",
            display_name="No Match Behavior",
            type_hint="enum",
            required=False,
            default="pass",
            # Derived from the description table so the allowed values are
            # declared in one place only.
            enum=list(_NO_MATCH_DESCRIPTIONS),
            description="Behavior when no match is found.",
            enum_options=enum_options_from_values(
                list(_NO_MATCH_DESCRIPTIONS.keys()),
                _NO_MATCH_DESCRIPTIONS,
                preserve_label_case=True,
            ),
            advance=True,
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "RegexSplitConfig":
        """Build a regex split config from a raw mapping.

        Args:
            data: Raw mapping; ``pattern`` is required and must be non-empty,
                all other keys are optional.
            path: Config path of this node, used in error locations.

        Returns:
            The populated ``RegexSplitConfig``.

        Raises:
            ConfigError: If ``group`` is neither ``str`` nor ``int``, or
                ``on_no_match`` is not an accepted value. The ``require_*`` /
                ``optional_*`` helpers are called for the remaining keys and
                presumably raise ``ConfigError`` as well (defined outside
                this file).
        """
        mapping = require_mapping(data, path)
        pattern = require_str(mapping, "pattern", path, allow_empty=False)

        group_value = mapping.get("group")
        group_normalized: str | int | None
        if group_value is None or isinstance(group_value, int):
            group_normalized = group_value
        elif isinstance(group_value, str):
            # str.isdecimal() is true only for characters int() can parse.
            # str.isdigit() also accepts e.g. "²", for which int() raises
            # ValueError; such strings are now kept as group names.
            group_normalized = int(group_value) if group_value.isdecimal() else group_value
        else:
            raise ConfigError("group must be str or int", extend_path(path, "group"))

        case_sensitive = optional_bool(mapping, "case_sensitive", path, default=True)
        multiline = optional_bool(mapping, "multiline", path, default=False)
        dotall = optional_bool(mapping, "dotall", path, default=False)
        on_no_match = optional_str(mapping, "on_no_match", path) or "pass"

        if on_no_match not in _NO_MATCH_DESCRIPTIONS:
            allowed = " or ".join(f"'{name}'" for name in _NO_MATCH_DESCRIPTIONS)
            raise ConfigError(f"on_no_match must be {allowed}", extend_path(path, "on_no_match"))

        return cls(
            pattern=pattern,
            group=group_normalized,
            case_sensitive=True if case_sensitive is None else bool(case_sensitive),
            multiline=bool(multiline) if multiline is not None else False,
            dotall=bool(dotall) if dotall is not None else False,
            on_no_match=on_no_match,
            path=path,
        )

    def display_label(self) -> str:
        """Return ``regex(<pattern>)`` as the label."""
        return f"regex({self.pattern})"
|
|
|
|
|
|
@dataclass
class JsonPathSplitConfig(SplitTypeConfig):
    """Configuration for JSON path-based splitting.

    Content is split by extracting the items of a JSON array located with a
    path expression; each array item becomes one execution unit.

    Attributes:
        json_path: Simple dot-notation path to array (e.g., 'items', 'data.results').
    """

    json_path: str = ""

    FIELD_SPECS = {
        "json_path": ConfigFieldSpec(
            name="json_path",
            display_name="JSON Path",
            type_hint="str",
            required=True,
            description="Simple dot-notation path to array (e.g., 'items', 'data.results').",
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "JsonPathSplitConfig":
        """Build the config from a raw mapping that must contain ``json_path``.

        The key is read with ``allow_empty=True``, so an empty string is
        accepted here.
        """
        section = require_mapping(data, path)
        return cls(
            json_path=require_str(section, "json_path", path, allow_empty=True),
            path=path,
        )

    def display_label(self) -> str:
        """Return ``json_path(<path>)`` as the label."""
        return f"json_path({self.json_path})"
|
|
|
|
|
|
# Registry of split types: name -> {"config_cls": config class, "summary": one-line text}.
_SPLIT_TYPE_REGISTRY: Dict[str, Dict[str, Any]] = {
    "message": dict(
        config_cls=MessageSplitConfig,
        summary="Each input message becomes one unit",
    ),
    "regex": dict(
        config_cls=RegexSplitConfig,
        summary="Split by regex pattern matches",
    ),
    "json_path": dict(
        config_cls=JsonPathSplitConfig,
        summary="Split by JSON array path",
    ),
}
|
|
|
|
|
|
def get_split_type_config(name: str) -> Type[SplitTypeConfig]:
    """Return the config class registered under ``name``.

    Raises:
        ConfigError: If no (non-empty) registry entry exists for ``name``.
    """
    registration = _SPLIT_TYPE_REGISTRY.get(name)
    if registration:
        return registration["config_cls"]
    raise ConfigError(f"Unknown split type: {name}", None)
|
|
|
|
|
|
def iter_split_type_registrations() -> Dict[str, Type[SplitTypeConfig]]:
    """Return a new dict mapping each registered split type name to its config class."""
    registrations: Dict[str, Type[SplitTypeConfig]] = {}
    for split_name, entry in _SPLIT_TYPE_REGISTRY.items():
        registrations[split_name] = entry["config_cls"]
    return registrations
|
|
|
|
|
|
def iter_split_type_metadata() -> Dict[str, Dict[str, Any]]:
    """Return a new dict mapping each split type name to ``{"summary": ...}``.

    The summary is ``None`` for an entry that has no ``"summary"`` key.
    """
    metadata: Dict[str, Dict[str, Any]] = {}
    for split_name, entry in _SPLIT_TYPE_REGISTRY.items():
        metadata[split_name] = {"summary": entry.get("summary")}
    return metadata
|
|
|
|
|
|
# Type variable bound to SplitTypeConfig; lets SplitConfig.as_split_config
# return the same subclass type that the caller passes in.
TSplitConfig = TypeVar("TSplitConfig", bound=SplitTypeConfig)
|
|
|
|
|
|
@dataclass
class SplitConfig(BaseConfig):
    """Configuration for how to split inputs into execution units.

    Attributes:
        type: Split strategy type (message, regex, json_path)
        config: Type-specific configuration
    """

    type: str = "message"
    config: SplitTypeConfig | None = None

    FIELD_SPECS = {
        "type": ConfigFieldSpec(
            name="type",
            display_name="Split Type",
            type_hint="str",
            required=True,
            default="message",
            description="Strategy for splitting inputs into parallel execution units",
        ),
        "config": ConfigFieldSpec(
            name="config",
            display_name="Split Config",
            type_hint="object",
            required=False,
            description="Type-specific split configuration",
        ),
    }

    @classmethod
    def child_routes(cls) -> Dict[ChildKey, Type[BaseConfig]]:
        """Map ``ChildKey(field="config", value=<split type>)`` to that type's config class."""
        return {
            ChildKey(field="config", value=name): config_cls
            for name, config_cls in iter_split_type_registrations().items()
        }

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the inherited field specs with the ``type`` enum filled from the registry."""
        specs = super().field_specs()
        type_spec = specs.get("type")
        if type_spec:
            registrations = iter_split_type_registrations()
            metadata = iter_split_type_metadata()
            type_names = list(registrations.keys())
            descriptions = {name: (metadata.get(name) or {}).get("summary") for name in type_names}
            specs["type"] = replace(
                type_spec,
                enum=type_names,
                enum_options=enum_options_from_values(type_names, descriptions),
            )
        return specs

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None, *, path: str) -> "SplitConfig":
        """Build a split config from a raw mapping.

        Args:
            data: Raw mapping with optional ``type`` (defaults to
                ``"message"``) and ``config`` keys, or ``None`` for the
                default message split.
            path: Config path of this node, used in error locations.

        Returns:
            The populated ``SplitConfig`` with its nested type-specific config.

        Raises:
            ConfigError: If ``type`` is not a registered split type, or a
                non-message type is given without a ``config`` value.
        """
        config_path = extend_path(path, "config")

        if data is None:
            # Default to message split
            return cls(type="message", config=MessageSplitConfig(path=config_path), path=path)

        mapping = require_mapping(data, path)
        split_type = optional_str(mapping, "type", path) or "message"

        if split_type not in _SPLIT_TYPE_REGISTRY:
            raise ConfigError(
                f"split type must be one of {list(_SPLIT_TYPE_REGISTRY.keys())}, got '{split_type}'",
                extend_path(path, "type"),
            )

        config_cls = get_split_type_config(split_type)
        config_data = mapping.get("config")

        # Only the message type may omit its config. The error is located at
        # the missing field itself, matching how the ``type`` error above is
        # located at ``<path>.type``.
        if config_data is None and split_type != "message":
            raise ConfigError(f"{split_type} split requires 'config' field", config_path)

        config = config_cls.from_dict(config_data, path=config_path)
        return cls(type=split_type, config=config, path=path)

    def display_label(self) -> str:
        """Return the nested config's label, or the type name when no config is set."""
        if self.config:
            return self.config.display_label()
        return self.type

    def to_external_value(self) -> Any:
        """Return ``{"type": ..., "config": ...}``; ``config`` is ``{}`` when unset."""
        return {
            "type": self.type,
            "config": self.config.to_external_value() if self.config else {},
        }

    def as_split_config(self, expected_type: Type[TSplitConfig]) -> TSplitConfig | None:
        """Return the nested config if it matches the expected type."""
        if isinstance(self.config, expected_type):
            return self.config
        return None

    # Convenience properties for backward compatibility and easy access
    @property
    def pattern(self) -> Optional[str]:
        """Get regex pattern if this is a regex split."""
        if isinstance(self.config, RegexSplitConfig):
            return self.config.pattern
        return None

    @property
    def json_path(self) -> Optional[str]:
        """Get json_path if this is a json_path split."""
        if isinstance(self.config, JsonPathSplitConfig):
            return self.config.json_path
        return None
|
|
|
|
|
|
@dataclass
class MapDynamicConfig(BaseConfig):
    """Configuration for Map dynamic mode (fan-out only).

    Map mode is similar to passthrough - minimal config required.

    Attributes:
        max_parallel: Maximum concurrent executions
    """

    max_parallel: int = 10

    FIELD_SPECS = {
        "max_parallel": ConfigFieldSpec(
            name="max_parallel",
            display_name="Max Parallel",
            type_hint="int",
            required=False,
            default=10,
            description="Maximum number of parallel executions",
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None, *, path: str) -> "MapDynamicConfig":
        """Build the map config from a raw mapping.

        Args:
            data: Raw mapping with an optional ``max_parallel`` key, or
                ``None`` to use the defaults.
            path: Config path of this node, used in error locations.

        Returns:
            The populated ``MapDynamicConfig``.

        Raises:
            ConfigError: If ``max_parallel`` cannot be converted to ``int``.
        """
        if data is None:
            return cls(path=path)
        mapping = require_mapping(data, path)

        raw_max_parallel = mapping.get("max_parallel")
        if raw_max_parallel is None:
            # A missing key and an explicit null both fall back to the default.
            max_parallel = 10
        else:
            try:
                max_parallel = int(raw_max_parallel)
            except (TypeError, ValueError) as exc:
                # Report bad values as a located config error instead of
                # leaking the bare conversion exception.
                raise ConfigError(
                    "max_parallel must be an integer",
                    extend_path(path, "max_parallel"),
                ) from exc

        return cls(max_parallel=max_parallel, path=path)
|
|
|
|
|
|
@dataclass
class TreeDynamicConfig(BaseConfig):
    """Configuration for Tree dynamic mode (fan-out and reduce).

    Attributes:
        group_size: Number of items per group in reduction
        max_parallel: Maximum concurrent executions per layer
    """

    group_size: int = 3
    max_parallel: int = 10

    FIELD_SPECS = {
        "group_size": ConfigFieldSpec(
            name="group_size",
            display_name="Group Size",
            type_hint="int",
            required=False,
            default=3,
            description="Number of items per group during reduction",
        ),
        "max_parallel": ConfigFieldSpec(
            name="max_parallel",
            display_name="Max Parallel",
            type_hint="int",
            required=False,
            default=10,
            description="Maximum concurrent executions per layer",
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None, *, path: str) -> "TreeDynamicConfig":
        """Build the tree config from a raw mapping.

        Args:
            data: Raw mapping with optional ``group_size`` and
                ``max_parallel`` keys, or ``None`` to use the defaults.
            path: Config path of this node, used in error locations.

        Returns:
            The populated ``TreeDynamicConfig``.

        Raises:
            ConfigError: If either value cannot be converted to ``int``, or
                ``group_size`` is less than 2.
        """
        if data is None:
            return cls(path=path)
        mapping = require_mapping(data, path)

        def read_int(key: str, default: int) -> int:
            # A missing key and an explicit null both fall back to the
            # default; unconvertible values become a located ConfigError
            # rather than a bare TypeError/ValueError.
            raw = mapping.get(key)
            if raw is None:
                return default
            try:
                return int(raw)
            except (TypeError, ValueError) as exc:
                raise ConfigError(f"{key} must be an integer", extend_path(path, key)) from exc

        group_size = read_int("group_size", 3)
        if group_size < 2:
            raise ConfigError("group_size must be at least 2", extend_path(path, "group_size"))
        max_parallel = read_int("max_parallel", 10)
        return cls(group_size=group_size, max_parallel=max_parallel, path=path)
|