mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-26 03:38:12 +00:00
83 lines
2.7 KiB
Python
Executable File
83 lines
2.7 KiB
Python
Executable File
"""Resource coordination helpers for workflow node execution."""
|
|
|
|
import threading
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Iterable, List, Tuple
|
|
|
|
from entity.configs import Node
|
|
from runtime.node.registry import get_node_registration
|
|
from utils.log_manager import LogManager
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ResourceRequest:
|
|
"""Represents a single resource requirement."""
|
|
|
|
key: str
|
|
limit: int
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _ResourceSlot:
|
|
semaphore: threading.Semaphore
|
|
limit: int
|
|
|
|
|
|
class ResourceManager:
    """Coordinates shared resource usage across nodes.

    Each distinct resource key is backed by exactly one ``threading.Semaphore``.
    It is created lazily with the limit carried by the first request for that
    key. A node holds the semaphores for its resources for the duration of
    :meth:`guard_node`.
    """

    def __init__(self, log_manager: LogManager | None = None):
        """Create a manager with no resources registered yet.

        Args:
            log_manager: Optional sink for debug messages about acquiring and
                releasing resources. When falsy, nothing is logged.
        """
        self.log_manager = log_manager
        # Serialises access to ``_resources`` so concurrent first requests for
        # the same key end up sharing a single slot.
        self._lock = threading.Lock()
        self._resources: Dict[str, _ResourceSlot] = {}

    @contextmanager
    def guard_node(self, node: Node):
        """Acquire all resources required by the given node.

        Blocks until every resource declared by the node type's registered
        capabilities is available. It then yields, and releases them when the
        ``with`` body exits, including when the body raises.

        Args:
            node: Node about to execute. Its ``node_type`` selects the
                registration whose capabilities declare the resource needs.
        """
        requests = self._resolve_node_requests(node)
        with self._acquire_resources(requests):
            yield

    def _resolve_node_requests(self, node: Node) -> List[ResourceRequest]:
        """Translate the node type's capabilities into resource requests.

        Returns:
            A single request built from the capabilities' ``resource_key`` and
            ``resource_limit``. Returns an empty list when the key is empty or
            the limit is missing or not positive.
        """
        caps = get_node_registration(node.node_type).capabilities
        key = caps.resource_key
        limit = caps.resource_limit
        if key and limit and limit > 0:
            return [ResourceRequest(key=key, limit=limit)]
        return []

    @contextmanager
    def _acquire_resources(self, requests: Iterable[ResourceRequest]):
        """Hold every requested semaphore for the duration of the ``with`` body.

        Semaphores are acquired in sorted key order. Because every caller uses
        the same order, two callers needing overlapping key sets cannot each
        hold one key while waiting on the other. Only semaphores that were
        actually acquired are released, in reverse acquisition order.
        """
        acquired: List[Tuple[str, threading.Semaphore]] = []
        try:
            for request in sorted(requests, key=lambda item: item.key):
                semaphore = self._get_or_create_resource(request)
                self._log_debug(f"Acquiring resource {request.key}")
                semaphore.acquire()
                acquired.append((request.key, semaphore))
            yield
        finally:
            # Release everything before logging, so a logger that raises cannot
            # leave the remaining semaphores permanently held.
            for _, semaphore in reversed(acquired):
                semaphore.release()
            for key, _ in reversed(acquired):
                self._log_debug(f"Released resource {key}")

    def _get_or_create_resource(self, request: ResourceRequest) -> threading.Semaphore:
        """Return the single semaphore shared by all requests for ``request.key``.

        The semaphore is created on first use with ``request.limit`` permits.
        Later requests for the same key reuse it even when they declare a
        different limit. Swapping in a new semaphore would let holders of the
        old one and acquirers of the new one run concurrently, so the key's
        concurrency bound would no longer hold.
        """
        with self._lock:
            slot = self._resources.get(request.key)
            if slot is None:
                slot = _ResourceSlot(
                    semaphore=threading.Semaphore(request.limit),
                    limit=request.limit,
                )
                self._resources[request.key] = slot
                return slot.semaphore
        # Log outside the lock to keep the critical section minimal.
        if slot.limit != request.limit:
            self._log_debug(
                f"Resource {request.key} already limited to {slot.limit}; "
                f"ignoring requested limit {request.limit}"
            )
        return slot.semaphore

    def _log_debug(self, message: str) -> None:
        """Forward ``message`` to ``log_manager.debug`` when a log manager is set."""
        if self.log_manager:
            self.log_manager.debug(message)
|