mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
150 lines
4.8 KiB
Python
Executable File
150 lines
4.8 KiB
Python
Executable File
"""Execution strategies for different graph topologies."""
|
|
|
|
from collections import Counter
|
|
from typing import Callable, Dict, List, Sequence
|
|
|
|
from entity.configs import Node
|
|
from entity.messages import Message
|
|
from utils.log_manager import LogManager
|
|
from workflow.executor.dag_executor import DAGExecutor
|
|
from workflow.executor.cycle_executor import CycleExecutor
|
|
from workflow.executor.parallel_executor import ParallelExecutor
|
|
|
|
|
|
class DagExecutionStrategy:
|
|
"""Executes acyclic graphs using the DAGExecutor."""
|
|
|
|
def __init__(
|
|
self,
|
|
log_manager: LogManager,
|
|
nodes: Dict[str, Node],
|
|
layers: List[List[str]],
|
|
execute_node_func: Callable[[Node], None],
|
|
) -> None:
|
|
self.log_manager = log_manager
|
|
self.nodes = nodes
|
|
self.layers = layers
|
|
self.execute_node_func = execute_node_func
|
|
|
|
def run(self) -> None:
|
|
dag_executor = DAGExecutor(
|
|
log_manager=self.log_manager,
|
|
nodes=self.nodes,
|
|
layers=self.layers,
|
|
execute_node_func=self.execute_node_func,
|
|
)
|
|
dag_executor.execute()
|
|
|
|
|
|
class CycleExecutionStrategy:
|
|
"""Executes graphs containing cycles via CycleExecutor."""
|
|
|
|
def __init__(
|
|
self,
|
|
log_manager: LogManager,
|
|
nodes: Dict[str, Node],
|
|
cycle_execution_order: List[Dict[str, str]],
|
|
cycle_manager,
|
|
execute_node_func: Callable[[Node], None],
|
|
) -> None:
|
|
self.log_manager = log_manager
|
|
self.nodes = nodes
|
|
self.cycle_execution_order = cycle_execution_order
|
|
self.cycle_manager = cycle_manager
|
|
self.execute_node_func = execute_node_func
|
|
|
|
def run(self) -> None:
|
|
cycle_executor = CycleExecutor(
|
|
log_manager=self.log_manager,
|
|
nodes=self.nodes,
|
|
cycle_execution_order=self.cycle_execution_order,
|
|
cycle_manager=self.cycle_manager,
|
|
execute_node_func=self.execute_node_func,
|
|
)
|
|
cycle_executor.execute()
|
|
|
|
|
|
class MajorityVoteStrategy:
|
|
"""Executes graphs configured for majority voting (no edges)."""
|
|
|
|
def __init__(
|
|
self,
|
|
log_manager: LogManager,
|
|
nodes: Dict[str, Node],
|
|
initial_messages: Sequence[Message],
|
|
execute_node_func: Callable[[Node], None],
|
|
payload_to_text_func: Callable[[object], str],
|
|
) -> None:
|
|
self.log_manager = log_manager
|
|
self.nodes = nodes
|
|
self.initial_messages = initial_messages
|
|
self.execute_node_func = execute_node_func
|
|
self.payload_to_text = payload_to_text_func
|
|
|
|
def run(self) -> str:
|
|
self.log_manager.info("Executing graph with majority voting approach")
|
|
all_nodes = list(self.nodes.values())
|
|
if not all_nodes:
|
|
self.log_manager.error("No nodes to execute in majority voting mode")
|
|
return ""
|
|
|
|
for node in all_nodes:
|
|
node.clear_input()
|
|
for message in self.initial_messages:
|
|
node.append_input(message.clone())
|
|
|
|
node_ids = [node.id for node in all_nodes]
|
|
|
|
def _execute(node_id: str) -> None:
|
|
self.execute_node_func(self.nodes[node_id])
|
|
|
|
parallel_executor = ParallelExecutor(self.log_manager, self.nodes)
|
|
parallel_executor.execute_nodes_parallel(node_ids, _execute)
|
|
|
|
return self._collect_majority_result()
|
|
|
|
def _collect_majority_result(self) -> str:
|
|
node_outputs: List[Dict[str, str]] = []
|
|
for node_id, node in self.nodes.items():
|
|
if node.output:
|
|
output_text = self.payload_to_text(node.output[-1])
|
|
else:
|
|
output_text = ""
|
|
node_outputs.append(
|
|
{
|
|
"node_id": node_id,
|
|
"node_type": node.node_type,
|
|
"output": output_text,
|
|
}
|
|
)
|
|
|
|
output_values = [item["output"] for item in node_outputs]
|
|
output_counts = Counter(output_values)
|
|
|
|
non_empty_outputs = [value for value in output_values if value.strip()]
|
|
if non_empty_outputs:
|
|
output_counts = Counter(non_empty_outputs)
|
|
|
|
if not output_counts:
|
|
self.log_manager.warning("No outputs available for majority voting")
|
|
return ""
|
|
|
|
majority_output, count = output_counts.most_common(1)[0]
|
|
self.log_manager.info(
|
|
"Majority output determined",
|
|
details={"result": majority_output, "votes": count},
|
|
)
|
|
self.log_manager.info(
|
|
"All node outputs",
|
|
details={
|
|
"outputs": [
|
|
(
|
|
item["node_id"],
|
|
item["output"][:50] + "..." if len(item["output"]) > 50 else item["output"],
|
|
)
|
|
for item in node_outputs
|
|
]
|
|
},
|
|
)
|
|
return majority_output
|