# ChatDev/workflow/executor/resource_manager.py
# 2026-01-07 16:24:01 +08:00

# 83 lines
# 2.7 KiB
# Python
# Executable File

"""Resource coordination helpers for workflow node execution."""
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple
from entity.configs import Node
from runtime.node.registry import get_node_registration
from utils.log_manager import LogManager
@dataclass(frozen=True, slots=True)
class ResourceRequest:
"""Represents a single resource requirement."""
key: str
limit: int
@dataclass(slots=True)
class _ResourceSlot:
semaphore: threading.Semaphore
limit: int
class ResourceManager:
    """Coordinates shared resource usage across nodes.

    Each resource key maps to one semaphore, created on first use with the
    limit declared by the node type's registered capabilities.
    ``guard_node`` is the public entry point; the remaining methods are
    internal helpers.
    """

    def __init__(self, log_manager: LogManager | None = None):
        """Create a manager with no resources registered yet.

        Args:
            log_manager: Optional logger. When it is falsy, debug messages
                are silently dropped.
        """
        self.log_manager = log_manager
        # Serializes access to ``_resources`` so callers share one slot per key.
        self._lock = threading.Lock()
        self._resources: Dict[str, _ResourceSlot] = {}

    @contextmanager
    def guard_node(self, node: Node):
        """Acquire all resources required by the given node.

        Blocks until every required semaphore has been obtained and releases
        them again when the ``with`` body exits, normally or by exception.

        Args:
            node: Node whose ``node_type`` determines the required resources.
        """
        requests = self._resolve_node_requests(node)
        with self._acquire_resources(requests):
            yield

    def _resolve_node_requests(self, node: Node) -> List[ResourceRequest]:
        """Translate a node's registered capabilities into resource requests.

        Returns:
            A list holding one request when the capabilities declare both a
            truthy ``resource_key`` and a positive ``resource_limit``;
            otherwise an empty list.
        """
        registration = get_node_registration(node.node_type)
        caps = registration.capabilities
        requests: List[ResourceRequest] = []
        key = caps.resource_key
        limit = caps.resource_limit
        # ``limit and`` guards against None before the numeric comparison.
        if key and limit and limit > 0:
            requests.append(ResourceRequest(key=key, limit=limit))
        return requests

    @contextmanager
    def _acquire_resources(self, requests: Iterable[ResourceRequest]):
        """Hold one permit of every requested resource for the ``with`` body.

        Args:
            requests: Resource requirements to satisfy.
        """
        held: List[Tuple[str, threading.Semaphore]] = []
        try:
            # Sorted by key so every caller takes semaphores in the same order.
            for request in sorted(requests, key=lambda item: item.key):
                semaphore = self._get_or_create_resource(request)
                self._log_debug(f"Acquiring resource {request.key}")
                semaphore.acquire()
                held.append((request.key, semaphore))
            yield
        finally:
            # Hand back every permit before logging anything: if the logger
            # raises, no semaphore is left permanently acquired.
            for _, semaphore in reversed(held):
                semaphore.release()
            for key, _ in reversed(held):
                self._log_debug(f"Released resource {key}")

    def _get_or_create_resource(self, request: ResourceRequest) -> threading.Semaphore:
        """Return the semaphore for ``request.key``, creating it when needed.

        A slot whose stored limit differs from the requested limit is
        replaced by a fresh semaphore. Permits already taken from the old
        semaphore are released back to that old object, so they are not
        counted against the replacement.
        """
        with self._lock:
            slot = self._resources.get(request.key)
            if slot is not None and slot.limit != request.limit:
                slot = None
            if slot is None:
                slot = _ResourceSlot(
                    semaphore=threading.Semaphore(request.limit),
                    limit=request.limit,
                )
                self._resources[request.key] = slot
            return slot.semaphore

    def _log_debug(self, message: str) -> None:
        """Forward ``message`` to the log manager's ``debug`` method, if one is set."""
        if self.log_manager:
            self.log_manager.debug(message)