ChatDev/workflow/hooks/workspace_artifact.py
2026-01-07 16:24:01 +08:00

275 lines
9.2 KiB
Python
Executable File

"""Hook that scans a node workspace for newly created files."""
import hashlib
import logging
import mimetypes
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
from entity.configs import Node
from entity.messages import MessageBlockType
from utils.attachments import AttachmentRecord, AttachmentStore
from utils.human_prompt import PromptChannel
@dataclass
class WorkspaceArtifact:
"""Represents a file artifact detected by the workspace hook."""
node_id: str
attachment_id: str
file_name: str
relative_path: str
absolute_path: str
mime_type: Optional[str]
size: Optional[int]
sha256: Optional[str]
data_uri: Optional[str]
created_at: float
change_type: str
extra: Dict[str, object]
@dataclass
class _FileSignature:
sha256: str
size: int
@dataclass
class _TrackedEntry:
sha256: str
attachment_id: str
absolute_path: str
mime_type: Optional[str]
size: Optional[int]
data_uri: Optional[str]
class WorkspaceArtifactHook:
"""Detects workspace file changes for selected node types."""
def __init__(
self,
*,
attachment_store: AttachmentStore,
emit_callback: Callable[[List[WorkspaceArtifact]], None],
node_types: Optional[Sequence[str]] = None,
exclude_dirs: Optional[Sequence[str]] = None,
max_files_scanned: int = 500,
max_bytes_scanned: int = 500 * 1024 * 1024,
prompt_channel: Optional[PromptChannel] = None,
) -> None:
self.attachment_store = attachment_store
self.emit_callback = emit_callback
self.node_types: Set[str] = set(node_types or {"python", "agent"})
self.exclude_dirs = set(exclude_dirs or {"attachments", "__pycache__"})
self.max_files_scanned = max_files_scanned
self.max_bytes_scanned = max_bytes_scanned
self.logger = logging.getLogger(__name__)
self._snapshots: Dict[str, Dict[str, _FileSignature]] = {}
self._last_emitted: Dict[str, _TrackedEntry] = {}
self.prompt_channel = prompt_channel
def can_handle(self, node: Node) -> bool:
return node.node_type in self.node_types
def get_prompt_channel(self) -> Optional[PromptChannel]:
return self.prompt_channel
def before_node(self, node: Node, workspace: Path) -> None:
if not self.can_handle(node):
return
snapshot, _ = self._snapshot(workspace)
self._snapshots[node.id] = snapshot
def after_node(
self,
node: Node,
workspace: Path,
*,
success: bool,
) -> None:
if not success or not self.can_handle(node):
self._snapshots.pop(node.id, None)
return
before = self._snapshots.pop(node.id, {})
after, truncated = self._snapshot(workspace)
if not after and not self._last_emitted:
return
changed_paths = [
Path(path_str)
for path_str, signature in after.items()
if path_str not in before or before[path_str].sha256 != signature.sha256
]
artifacts: List[WorkspaceArtifact] = []
for relative_path in changed_paths:
signature = after[str(relative_path)]
full_path = workspace / relative_path
if not full_path.exists() or not full_path.is_file():
continue
try:
tracked = self._last_emitted.get(str(relative_path))
change_type = "created" if tracked is None else "updated"
record = self._register_artifact(
full_path,
relative_path,
node,
attachment_id=tracked.attachment_id if tracked else None,
)
except Exception as exc:
self.logger.warning(
"Failed to register artifact %s for node %s: %s",
relative_path,
node.id,
exc,
)
continue
artifacts.append(
self._to_artifact(
record,
node,
relative_path,
full_path,
change_type=change_type,
)
)
self._last_emitted[str(relative_path)] = _TrackedEntry(
sha256=signature.sha256,
attachment_id=record.ref.attachment_id or "",
absolute_path=str(full_path),
mime_type=record.ref.mime_type,
size=record.ref.size,
data_uri=record.ref.data_uri,
)
if not truncated:
deleted_paths = [
relative_path
for relative_path in list(self._last_emitted.keys())
if relative_path not in after
]
for relative_path in deleted_paths:
tracked = self._last_emitted.pop(relative_path, None)
if not tracked:
continue
artifacts.append(
WorkspaceArtifact(
node_id=node.id,
attachment_id=tracked.attachment_id,
file_name=Path(relative_path).name,
relative_path=relative_path,
absolute_path=tracked.absolute_path,
mime_type=tracked.mime_type,
size=tracked.size,
sha256=tracked.sha256,
data_uri=tracked.data_uri,
created_at=time.time(),
change_type="deleted",
extra={
"hook": "workspace_scan",
"relative_path": relative_path,
},
)
)
if artifacts:
self.emit_callback(artifacts)
def _snapshot(self, workspace: Path) -> Tuple[Dict[str, _FileSignature], bool]:
entries: Dict[str, _FileSignature] = {}
total_bytes = 0
file_count = 0
for root, dirs, files in os.walk(workspace):
rel_root = Path(root).relative_to(workspace)
dirs[:] = [d for d in dirs if not self._is_excluded(rel_root / d)]
for filename in files:
rel_path = rel_root / filename
if self._is_excluded(rel_path):
continue
full_path = Path(root) / filename
try:
stat = full_path.stat()
sha256 = self._hash_file(full_path)
except OSError:
continue
file_count += 1
total_bytes += stat.st_size
entries[str(rel_path)] = _FileSignature(sha256=sha256, size=stat.st_size)
if file_count >= self.max_files_scanned or total_bytes >= self.max_bytes_scanned:
self.logger.warning(
"Workspace scan truncated (files=%s total_bytes=%s) for session %s",
file_count,
total_bytes,
)
return entries, True
return entries, False
def _is_excluded(self, rel_path: Path) -> bool:
if not rel_path.parts:
return False
return rel_path.parts[0] in self.exclude_dirs
def _register_artifact(
self,
full_path: Path,
relative_path: Path,
node: Node,
*,
attachment_id: Optional[str] = None,
) -> AttachmentRecord:
mime_type = mimetypes.guess_type(relative_path.name)[0] or "application/octet-stream"
return self.attachment_store.register_file(
full_path,
kind=MessageBlockType.from_mime_type(mime_type),
mime_type=mime_type,
display_name=full_path.name,
copy_file=False,
persist=True,
deduplicate=False,
attachment_id=attachment_id,
extra={
"node_id": node.id,
"relative_path": str(relative_path),
"hook": "workspace_scan",
},
)
def _to_artifact(
self,
record: AttachmentRecord,
node: Node,
relative_path: Path,
full_path: Path,
*,
change_type: str,
) -> WorkspaceArtifact:
ref = record.ref
return WorkspaceArtifact(
node_id=node.id,
attachment_id=ref.attachment_id or "",
file_name=ref.name or full_path.name,
relative_path=str(relative_path),
absolute_path=str(full_path),
mime_type=ref.mime_type,
size=ref.size,
sha256=ref.sha256,
data_uri=ref.data_uri,
created_at=time.time(),
change_type=change_type,
extra=dict(record.extra),
)
def _hash_file(self, path: Path) -> str:
hasher = hashlib.sha256()
with path.open("rb") as handle:
for chunk in iter(lambda: handle.read(1024 * 1024), b""):
hasher.update(chunk)
return hasher.hexdigest()