mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 19:28:09 +00:00
234 lines
8.3 KiB
Python
Executable File
234 lines
8.3 KiB
Python
Executable File
"""EdgeConditionManager abstraction that unifies edge condition compilation."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Generic, TypeVar
|
|
|
|
from entity.messages import Message, MessageRole
|
|
from runtime.node.executor import ExecutionContext
|
|
from utils.log_manager import LogManager
|
|
from utils.structured_logger import get_server_logger
|
|
from entity.configs.node.node import EdgeLink, Node
|
|
from entity.configs.edge.edge_condition import EdgeConditionTypeConfig
|
|
from utils.function_manager import FunctionManager
|
|
|
|
# Signature of a compiled edge condition: receives text and returns whether the condition holds.
ConditionEvaluator = Callable[[str], bool]
|
|
|
|
# Type variable for the configuration type a manager is parameterized with;
# bounded by EdgeConditionTypeConfig (given as a forward-reference string).
TConfig = TypeVar("TConfig", bound="EdgeConditionTypeConfig")
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ConditionFactoryContext:
|
|
"""Context passed to managers at initialization time."""
|
|
|
|
function_manager: "FunctionManager | None" = None
|
|
log_manager: "LogManager | None" = None
|
|
|
|
|
|
class EdgeConditionManager(Generic[TConfig], ABC):
    """Abstract base class for edge condition managers.

    Every edge condition type supplies a concrete subclass implementing
    :meth:`process`. The helpers defined here cover the shared flow:
    evaluate a condition against the source result, optionally clear the
    target node's inputs, build and transform the payload, append it to the
    target node and mark the edge as triggered.
    """

    def __init__(
        self,
        config: TConfig,
        ctx: ConditionFactoryContext,
        execution_context: ExecutionContext,
    ) -> None:
        """Store the configuration and contexts used while processing edges.

        Args:
            config: Condition configuration this manager was created for.
            ctx: Factory context; its ``log_manager`` (when set) receives
                payload processor failure reports.
            execution_context: Forwarded to the edge payload processor as
                its ``context`` argument.
        """
        self.config = config
        self.ctx = ctx
        self.execution_context = execution_context

    @abstractmethod
    def process(
        self,
        edge_link: "EdgeLink",
        source_result: Message,
        from_node: "Node",
        log_manager: LogManager,
    ) -> None:
        """Process one outgoing edge; the logic is implemented by subclasses.

        Args:
            edge_link: The edge being processed.
            source_result: Result produced by the source node.
            from_node: Node the edge originates from.
            log_manager: Log manager used for reporting during processing.
        """

    def transform_payload(
        self,
        payload: Message,
        *,
        source_result: Message,
        from_node: "Node",
        edge_link: "EdgeLink",
        log_manager: LogManager,
    ) -> Message | None:
        """Run the edge's payload processor over ``payload``, if one is set.

        Args:
            payload: Message prepared for the target node.
            source_result: Original result of the source node.
            from_node: Node the edge originates from.
            edge_link: Edge whose ``payload_processor`` is applied.
            log_manager: Log manager forwarded to the processor; also used
                for error reporting when the factory context has none.

        Returns:
            Whatever the processor returns (``None`` is treated by the caller
            in this class as "message dropped"). When the edge has no
            processor, or the processor raises, the untransformed ``payload``
            is returned.
        """
        processor = edge_link.payload_processor
        if not processor:
            return payload

        try:
            return processor.transform(
                payload,
                source_result=source_result,
                from_node=from_node,
                edge_link=edge_link,
                log_manager=log_manager,
                context=self.execution_context,
            )
        except Exception as exc:  # pragma: no cover
            error_msg = (
                f"Edge payload processor failed for {from_node.id}->{edge_link.target.id}: {exc}"
            )
            processor_type = getattr(edge_link, "process_type", None)
            processor_metadata = getattr(edge_link, "process_metadata", {})

            # Prefer the factory context's log manager; fall back to the one
            # passed by the caller so the failure is never missing from the
            # run log just because the context carries no log manager.
            error_log = log_manager
            if self.ctx and self.ctx.log_manager:
                error_log = self.ctx.log_manager
            if error_log is not None:
                error_log.error(
                    error_msg,
                    details={
                        "processor_type": processor_type,
                        "processor_metadata": processor_metadata,
                    },
                )

            server_logger = get_server_logger()
            server_logger.log_exception(
                exc,
                error_msg,
                processor_type=processor_type,
                processor_metadata=processor_metadata,
            )
            # Best effort: a failing processor does not block the edge.
            return payload

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    def _process_with_condition(
        self,
        evaluator: ConditionEvaluator,
        *,
        label: str,
        metadata: dict[str, Any],
        edge_link: "EdgeLink",
        source_result: Message,
        from_node: "Node",
        log_manager: LogManager,
    ) -> None:
        """Evaluate ``evaluator`` and, when it holds, deliver data over the edge.

        Steps, in order: record the edge in ``log_manager``; evaluate the
        condition on the text form of ``source_result``; stop if it is not
        met; clear the target's inputs as configured on the edge; if the edge
        carries data, prepare, transform and append the payload to the target
        node; finally mark the edge as triggered when ``edge_link.trigger``
        is set.

        Args:
            evaluator: Callable taking the serialized source text and
                returning a truthy value when the condition is met.
            label: Name of the condition, used in error messages.
            metadata: Condition metadata attached to error reports.
            edge_link: Edge being processed.
            source_result: Result produced by the source node.
            from_node: Node the edge originates from.
            log_manager: Log manager receiving progress and error entries.
        """
        target_node = edge_link.target
        log_manager.record_edge_process(from_node.id, target_node.id, edge_link.config)

        serialized_source = self._payload_to_text(source_result)
        try:
            condition_met = bool(evaluator(serialized_source))
        except Exception as exc:  # pragma: no cover
            error_msg = f"Error calling condition '{label}': {exc}"
            log_manager.error(
                error_msg,
                details={
                    "condition_type": edge_link.condition_type,
                    "condition_metadata": metadata,
                },
            )
            server_logger = get_server_logger()
            server_logger.log_exception(
                exc,
                error_msg,
                condition_type=edge_link.condition_type,
                condition_metadata=metadata,
            )
            # NOTE(review): an evaluator failure is treated as "condition
            # met" (fail-open) - confirm this is the intended policy.
            condition_met = True

        if not condition_met:
            log_manager.debug(
                f"Edge condition not met for {from_node.id} -> {target_node.id}, skipping edge processing"
            )
            return

        log_manager.info(f"Edge condition met for {from_node.id} -> {target_node.id}")
        self._clear_target_context(
            target_node,
            drop_non_keep=getattr(edge_link, "clear_context", False),
            drop_keep=getattr(edge_link, "clear_kept_context", False),
            from_node=from_node,
            log_manager=log_manager,
        )

        if edge_link.carry_data:
            payload = self._prepare_payload_for_target(
                source_result, from_node, target_node, edge_link.keep_message
            )
            payload = self.transform_payload(
                payload,
                source_result=source_result,
                from_node=from_node,
                edge_link=edge_link,
                log_manager=log_manager,
            )
            if payload is None:
                log_manager.debug(
                    f"Payload processor dropped message for edge {from_node.id} -> {target_node.id}"
                )
                # A dropped message also leaves the edge untriggered.
                return

            # Tag message with dynamic edge info for later processing
            if edge_link.dynamic_config is not None:
                # Separate local so the ``metadata`` parameter (condition
                # metadata) is not rebound; ``or {}`` tolerates a missing
                # metadata mapping on the payload.
                payload_metadata = dict(payload.metadata or {})
                payload_metadata["_from_dynamic_edge"] = True
                payload_metadata["_dynamic_edge_source"] = from_node.id
                payload.metadata = payload_metadata

            target_node.append_input(payload)
            log_manager.debug(
                f"Data passed from {from_node.id} to {target_node.id}'s input queue "
                f"(type: {self._describe_payload(payload)})"
            )
        else:
            log_manager.debug(
                f"Edge {from_node.id} -> {target_node.id} does not carry data, skipping data transfer"
            )

        if edge_link.trigger:
            edge_link.triggered = True
            log_manager.debug(f"Edge {from_node.id} -> {target_node.id} triggered")

    def _payload_to_text(self, payload: Any) -> str:
        """Return the text form of ``payload`` used for condition evaluation.

        A ``Message`` yields its ``text_content()``, ``None`` yields an empty
        string, and anything else is converted with ``str``.
        """
        if isinstance(payload, Message):
            return payload.text_content()
        if payload is None:
            return ""
        return str(payload)

    def _prepare_payload_for_target(
        self,
        payload: Any,
        from_node: Node,
        target_node: Node,
        keep: bool = False,
    ) -> Message:
        """Build the message that will be appended to the target node.

        Args:
            payload: Source result; non-``Message`` values are wrapped in an
                assistant message holding ``str(payload)``.
            from_node: Node the edge originates from.
            target_node: Node the edge points to.
            keep: When true, the returned message has ``keep`` set.

        Returns:
            A clone of the payload with ``metadata["source"]`` set to
            ``from_node.id``. Unless the clone has ``preserve_role`` set, its
            role is ``ASSISTANT`` when source and target are the same node
            and ``USER`` otherwise.
        """
        if not isinstance(payload, Message):
            payload = Message(role=MessageRole.ASSISTANT, content=str(payload))
        cloned = payload.clone()

        if not cloned.preserve_role:
            is_self_edge = from_node.id == target_node.id
            cloned.role = MessageRole.ASSISTANT if is_self_edge else MessageRole.USER

        # Copy before mutating; ``or {}`` tolerates a missing metadata mapping.
        clone_metadata = dict(cloned.metadata or {})
        clone_metadata["source"] = from_node.id
        cloned.metadata = clone_metadata

        if keep:
            cloned.keep = True
        return cloned

    def _clear_target_context(
        self,
        target_node: Node,
        *,
        drop_non_keep: bool,
        drop_keep: bool,
        from_node: Node,
        log_manager: LogManager,
    ) -> None:
        """Clear the target node's inputs according to the two drop flags.

        Does nothing when both flags are false. Otherwise delegates to
        ``target_node.clear_inputs_by_flag`` and writes a debug entry when
        that call reports anything removed.

        Args:
            target_node: Node whose inputs are cleared.
            drop_non_keep: Forwarded to ``clear_inputs_by_flag``.
            drop_keep: Forwarded to ``clear_inputs_by_flag``.
            from_node: Source node, used only in the log message.
            log_manager: Log manager receiving the debug entry.
        """
        if not (drop_non_keep or drop_keep):
            return

        removed_non_keep, removed_keep = target_node.clear_inputs_by_flag(
            drop_non_keep=drop_non_keep,
            drop_keep=drop_keep,
        )
        if removed_non_keep or removed_keep:
            log_manager.debug(
                f"Cleared target context for edge {from_node.id} -> {target_node.id}",
                details={
                    "removed_non_keep": removed_non_keep,
                    "removed_keep": removed_keep,
                    "drop_non_keep": drop_non_keep,
                    "drop_keep": drop_keep,
                },
            )

    def _describe_payload(self, payload: Any) -> str:
        """Return a short label for ``payload`` used in log messages.

        Messages are described as ``message:<role value>``; any other object
        by its type name.
        """
        if isinstance(payload, Message):
            return f"message:{payload.role.value}"
        return type(payload).__name__
|