ChatDev/workflow/runtime/execution_strategy.py
2026-01-07 16:24:01 +08:00

150 lines
4.8 KiB
Python
Executable File

"""Execution strategies for different graph topologies."""
from collections import Counter
from typing import Callable, Dict, List, Sequence
from entity.configs import Node
from entity.messages import Message
from utils.log_manager import LogManager
from workflow.executor.dag_executor import DAGExecutor
from workflow.executor.cycle_executor import CycleExecutor
from workflow.executor.parallel_executor import ParallelExecutor
class DagExecutionStrategy:
"""Executes acyclic graphs using the DAGExecutor."""
def __init__(
self,
log_manager: LogManager,
nodes: Dict[str, Node],
layers: List[List[str]],
execute_node_func: Callable[[Node], None],
) -> None:
self.log_manager = log_manager
self.nodes = nodes
self.layers = layers
self.execute_node_func = execute_node_func
def run(self) -> None:
dag_executor = DAGExecutor(
log_manager=self.log_manager,
nodes=self.nodes,
layers=self.layers,
execute_node_func=self.execute_node_func,
)
dag_executor.execute()
class CycleExecutionStrategy:
"""Executes graphs containing cycles via CycleExecutor."""
def __init__(
self,
log_manager: LogManager,
nodes: Dict[str, Node],
cycle_execution_order: List[Dict[str, str]],
cycle_manager,
execute_node_func: Callable[[Node], None],
) -> None:
self.log_manager = log_manager
self.nodes = nodes
self.cycle_execution_order = cycle_execution_order
self.cycle_manager = cycle_manager
self.execute_node_func = execute_node_func
def run(self) -> None:
cycle_executor = CycleExecutor(
log_manager=self.log_manager,
nodes=self.nodes,
cycle_execution_order=self.cycle_execution_order,
cycle_manager=self.cycle_manager,
execute_node_func=self.execute_node_func,
)
cycle_executor.execute()
class MajorityVoteStrategy:
"""Executes graphs configured for majority voting (no edges)."""
def __init__(
self,
log_manager: LogManager,
nodes: Dict[str, Node],
initial_messages: Sequence[Message],
execute_node_func: Callable[[Node], None],
payload_to_text_func: Callable[[object], str],
) -> None:
self.log_manager = log_manager
self.nodes = nodes
self.initial_messages = initial_messages
self.execute_node_func = execute_node_func
self.payload_to_text = payload_to_text_func
def run(self) -> str:
self.log_manager.info("Executing graph with majority voting approach")
all_nodes = list(self.nodes.values())
if not all_nodes:
self.log_manager.error("No nodes to execute in majority voting mode")
return ""
for node in all_nodes:
node.clear_input()
for message in self.initial_messages:
node.append_input(message.clone())
node_ids = [node.id for node in all_nodes]
def _execute(node_id: str) -> None:
self.execute_node_func(self.nodes[node_id])
parallel_executor = ParallelExecutor(self.log_manager, self.nodes)
parallel_executor.execute_nodes_parallel(node_ids, _execute)
return self._collect_majority_result()
def _collect_majority_result(self) -> str:
node_outputs: List[Dict[str, str]] = []
for node_id, node in self.nodes.items():
if node.output:
output_text = self.payload_to_text(node.output[-1])
else:
output_text = ""
node_outputs.append(
{
"node_id": node_id,
"node_type": node.node_type,
"output": output_text,
}
)
output_values = [item["output"] for item in node_outputs]
output_counts = Counter(output_values)
non_empty_outputs = [value for value in output_values if value.strip()]
if non_empty_outputs:
output_counts = Counter(non_empty_outputs)
if not output_counts:
self.log_manager.warning("No outputs available for majority voting")
return ""
majority_output, count = output_counts.most_common(1)[0]
self.log_manager.info(
"Majority output determined",
details={"result": majority_output, "votes": count},
)
self.log_manager.info(
"All node outputs",
details={
"outputs": [
(
item["node_id"],
item["output"][:50] + "..." if len(item["output"]) > 50 else item["output"],
)
for item in node_outputs
]
},
)
return majority_output