mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-14 04:33:45 +00:00
* feat: real-time subagent token usage display in header and per-turn Backend: - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out) - Report subagent usage to parent RunJournal for API totals - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution Frontend: - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata - Simplify selectHeaderTokenUsage (no subagentUsage parameter) - Per-turn inline badges show turn-specific usage via message accumulation - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming * fix: prevent header token double counting from baseline reset race onFinish, onError, and thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count. Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end. * fix: write back subagent tokens for all concurrent task tool calls TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls only the last ToolMessage had its cached subagent usage written back to the dispatch AIMessage.usage_metadata. 
Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (leak) and never appeared in the per-turn inline token display. Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other. * fix: clean up subagent usage cache entry on task cancellation When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict. * fix: address token usage review feedback * fix: handle missing config for subagent usage cache --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
423 lines
19 KiB
Python
423 lines
19 KiB
Python
"""Task tool for delegating work to subagents."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from dataclasses import replace
|
|
from typing import TYPE_CHECKING, Annotated, Any, cast
|
|
|
|
from langchain.tools import InjectedToolCallId, tool
|
|
from langgraph.config import get_stream_writer
|
|
|
|
from deerflow.config import get_app_config
|
|
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
|
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
|
from deerflow.subagents.config import resolve_subagent_model_name
|
|
from deerflow.subagents.executor import (
|
|
SubagentStatus,
|
|
cleanup_background_task,
|
|
get_background_task_result,
|
|
request_cancel_background_task,
|
|
)
|
|
from deerflow.tools.types import Runtime
|
|
|
|
if TYPE_CHECKING:
|
|
from deerflow.config.app_config import AppConfig
|
|
|
|
logger = logging.getLogger(__name__)


# Module-level cache of subagent token usage keyed by tool_call_id, shared by every
# task_tool invocation in this process. task_tool stores the summarized usage dict
# (input_tokens / output_tokens / total_tokens, see _summarize_usage) when a subagent
# reaches a terminal status or polling times out, so TokenUsageMiddleware can write
# it back to the triggering AIMessage's usage_metadata (presumably by calling
# pop_cached_subagent_usage — confirm in the middleware). task_tool itself drops the
# entry when the call is cancelled or raises; the dict is otherwise unbounded.
_subagent_usage_cache: dict[str, dict[str, int]] = {}
|
|
|
|
|
|
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
|
|
if app_config is None:
|
|
try:
|
|
app_config = get_app_config()
|
|
except FileNotFoundError:
|
|
return False
|
|
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
|
|
|
|
|
|
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
    """Remember ``usage`` under ``tool_call_id`` for later writeback.

    Nothing is stored when caching is disabled or ``usage`` is empty/``None``.
    An existing entry for the same id is overwritten.
    """
    if not (enabled and usage):
        return
    _subagent_usage_cache[tool_call_id] = usage
|
|
|
|
|
|
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
    """Remove and return the cached subagent usage for ``tool_call_id``.

    Returns ``None`` when nothing is cached for that id.
    """
    try:
        return _subagent_usage_cache.pop(tool_call_id)
    except KeyError:
        return None
|
|
|
|
|
|
def _is_subagent_terminal(result: Any) -> bool:
    """Return whether a background subagent result is safe to clean up.

    A result counts as finished when its status is COMPLETED, FAILED, CANCELLED
    or TIMED_OUT, or when it carries a non-``None`` ``completed_at``.
    """
    finished_states = {
        SubagentStatus.COMPLETED,
        SubagentStatus.FAILED,
        SubagentStatus.CANCELLED,
        SubagentStatus.TIMED_OUT,
    }
    if result.status in finished_states:
        return True
    return getattr(result, "completed_at", None) is not None
|
|
|
|
|
|
async def _await_subagent_terminal(task_id: str, max_polls: int) -> Any | None:
    """Poll until the background subagent reaches a terminal status or we run out of polls.

    Checks the task up to ``max_polls`` times, sleeping 5 seconds between checks.

    Returns:
        The terminal result, or ``None`` if the task is no longer registered or
        never became terminal within the poll budget.
    """
    for _attempt in range(max_polls):
        snapshot = get_background_task_result(task_id)
        # A vanished task (None) and a finished task both end the wait right away.
        if snapshot is None or _is_subagent_terminal(snapshot):
            return snapshot
        await asyncio.sleep(5)
    return None
|
|
|
|
|
|
async def _deferred_cleanup_subagent_task(task_id: str, trace_id: str, max_polls: int) -> None:
    """Keep polling a cancelled subagent until it can be safely removed.

    Stops silently if the task disappears. Calls ``cleanup_background_task`` once it
    is terminal. Gives up with a warning after ``max_polls`` sleeps of 5 seconds.
    """
    polls_done = 0
    while (snapshot := get_background_task_result(task_id)) is not None:
        if _is_subagent_terminal(snapshot):
            cleanup_background_task(task_id)
            return
        if polls_done >= max_polls:
            logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {polls_done} polls")
            return
        await asyncio.sleep(5)
        polls_done += 1
|
|
|
|
|
|
def _log_cleanup_failure(cleanup_task: asyncio.Task[None], *, trace_id: str, task_id: str) -> None:
|
|
if cleanup_task.cancelled():
|
|
return
|
|
|
|
exc = cleanup_task.exception()
|
|
if exc is not None:
|
|
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
|
|
|
|
|
# Strong references to in-flight deferred-cleanup tasks. The asyncio event loop
# only keeps weak references to tasks, so a fire-and-forget task with no other
# reference may be garbage-collected before it finishes.
_deferred_cleanup_tasks: set[asyncio.Task[None]] = set()


def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: int) -> None:
    """Start a background task that removes ``task_id`` once it becomes terminal.

    Must be called from a running event loop (it uses ``asyncio.create_task``).
    The created task is kept in ``_deferred_cleanup_tasks`` until it finishes.
    Any exception it raises is logged by ``_log_cleanup_failure``.
    """
    logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
    cleanup_task = asyncio.create_task(_deferred_cleanup_subagent_task(task_id, trace_id, max_polls))
    # Hold a strong reference until the task is done, then drop it.
    _deferred_cleanup_tasks.add(cleanup_task)
    cleanup_task.add_done_callback(_deferred_cleanup_tasks.discard)
    cleanup_task.add_done_callback(lambda task: _log_cleanup_failure(task, trace_id=trace_id, task_id=task_id))
|
|
|
|
|
|
def _find_usage_recorder(runtime: Any) -> Any | None:
|
|
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
|
|
if runtime is None:
|
|
return None
|
|
config = getattr(runtime, "config", None)
|
|
if not isinstance(config, dict):
|
|
return None
|
|
callbacks = config.get("callbacks", [])
|
|
if not callbacks:
|
|
return None
|
|
for cb in callbacks:
|
|
if hasattr(cb, "record_external_llm_usage_records"):
|
|
return cb
|
|
return None
|
|
|
|
|
|
def _summarize_usage(records: list[dict] | None) -> dict | None:
|
|
"""Summarize token usage records into a compact dict for SSE events."""
|
|
if not records:
|
|
return None
|
|
return {
|
|
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
|
|
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
|
|
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
|
|
}
|
|
|
|
|
|
def _report_subagent_usage(runtime: Any, result: Any) -> None:
|
|
"""Report subagent token usage to the parent RunJournal, if available.
|
|
|
|
Each subagent task must be reported only once (guarded by usage_reported).
|
|
"""
|
|
if getattr(result, "usage_reported", True):
|
|
return
|
|
records = getattr(result, "token_usage_records", None) or []
|
|
if not records:
|
|
return
|
|
journal = _find_usage_recorder(runtime)
|
|
if journal is None:
|
|
logger.debug("No usage recorder found in runtime callbacks — subagent token usage not recorded")
|
|
return
|
|
try:
|
|
journal.record_external_llm_usage_records(records)
|
|
result.usage_reported = True
|
|
except Exception:
|
|
logger.warning("Failed to report subagent token usage", exc_info=True)
|
|
|
|
|
|
def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
|
|
context = getattr(runtime, "context", None)
|
|
if isinstance(context, dict):
|
|
app_config = context.get("app_config")
|
|
if app_config is not None:
|
|
return cast("AppConfig", app_config)
|
|
return None
|
|
|
|
|
|
def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -> list[str] | None:
|
|
"""Return the effective subagent skill allowlist under the parent policy."""
|
|
if parent is None:
|
|
return child
|
|
if child is None:
|
|
return list(parent)
|
|
|
|
parent_set = set(parent)
|
|
return [skill for skill in child if skill in parent_set]
|
|
|
|
|
|
@tool("task", parse_docstring=True)
async def task_tool(
    runtime: Runtime,
    description: str,
    prompt: str,
    subagent_type: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> str:
    """Delegate a task to a specialized subagent that runs in its own context.

    Subagents help you:
    - Preserve context by keeping exploration and implementation separate
    - Handle complex multi-step tasks autonomously
    - Execute commands or operations in isolated contexts

    Built-in subagent types:
    - **general-purpose**: A capable agent for complex, multi-step tasks that require
      both exploration and action. Use when the task requires complex reasoning,
      multiple dependent steps, or would benefit from isolated context.
    - **bash**: Command execution specialist for running bash commands. This is only
      available when host bash is explicitly allowed or when using an isolated shell
      sandbox such as `AioSandboxProvider`.

    Additional custom subagent types may be defined in config.yaml under
    `subagents.custom_agents`. Each custom type can have its own system prompt,
    tools, skills, model, and timeout configuration. If an unknown subagent_type
    is provided, the error message will list all available types.

    When to use this tool:
    - Complex tasks requiring multiple steps or tools
    - Tasks that produce verbose output
    - When you want to isolate context from the main conversation
    - Parallel research or exploration tasks

    When NOT to use this tool:
    - Simple, single-step operations (use tools directly)
    - Tasks requiring user interaction or clarification

    Args:
        description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
        prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
        subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
    """
    # NOTE: @tool(parse_docstring=True) parses the docstring above into the tool's
    # description and argument descriptions, so edit it deliberately.
    #
    # Flow: resolve the subagent config, start it in the background (task_id ==
    # tool_call_id), then poll every 5s. Each poll streams task_* events through the
    # LangGraph stream writer and returns a plain-text result on a terminal status.
    runtime_app_config = _get_runtime_app_config(runtime)
    cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
    available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()

    # Get subagent configuration
    config = get_subagent_config(subagent_type, app_config=runtime_app_config) if runtime_app_config is not None else get_subagent_config(subagent_type)
    if config is None:
        available = ", ".join(available_subagent_names)
        return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
    if subagent_type == "bash":
        host_bash_allowed = is_host_bash_allowed(runtime_app_config) if runtime_app_config is not None else is_host_bash_allowed()
        if not host_bash_allowed:
            return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"

    # Build config overrides
    overrides: dict = {}

    # Skills are loaded by SubagentExecutor per-session (aligned with Codex's pattern:
    # each subagent loads its own skills based on config, injected as conversation items).
    # No longer appended to system_prompt here.

    # Extract parent context from runtime
    sandbox_state = None
    thread_data = None
    thread_id = None
    parent_model = None
    metadata: dict = {}

    if runtime is not None:
        sandbox_state = runtime.state.get("sandbox")
        thread_data = runtime.state.get("thread_data")
        thread_id = runtime.context.get("thread_id") if runtime.context else None
        if thread_id is None:
            thread_id = runtime.config.get("configurable", {}).get("thread_id")

        # Try to get parent model from configurable
        metadata = runtime.config.get("metadata", {})
        parent_model = metadata.get("model_name")

        # The subagent may only use skills that the parent run also allows.
        parent_available_skills = metadata.get("available_skills")
        if parent_available_skills is not None:
            overrides["skills"] = _merge_skill_allowlists(list(parent_available_skills), config.skills)

    # Get or generate trace_id for distributed tracing. Generated even without a
    # runtime so log lines and deferred cleanup always get a real id.
    trace_id = metadata.get("trace_id") or str(uuid.uuid4())[:8]

    if overrides:
        config = replace(config, **overrides)

    # Get available tools (excluding task tool to prevent nesting)
    # Lazy import to avoid circular dependency
    from deerflow.tools import get_available_tools

    # Inherit parent agent's tool_groups so subagents respect the same restrictions
    parent_tool_groups = metadata.get("tool_groups")
    resolved_app_config = runtime_app_config
    if config.model == "inherit" and parent_model is None and resolved_app_config is None:
        resolved_app_config = get_app_config()
    effective_model = resolve_subagent_model_name(config, parent_model, app_config=resolved_app_config)

    # Subagents should not have subagent tools enabled (prevent recursive nesting)
    available_tools_kwargs = {
        "model_name": effective_model,
        "groups": parent_tool_groups,
        "subagent_enabled": False,
    }
    if resolved_app_config is not None:
        available_tools_kwargs["app_config"] = resolved_app_config
    tools = get_available_tools(**available_tools_kwargs)

    # Create executor
    executor_kwargs = {
        "config": config,
        "tools": tools,
        "parent_model": parent_model,
        "sandbox_state": sandbox_state,
        "thread_data": thread_data,
        "thread_id": thread_id,
        "trace_id": trace_id,
    }
    if resolved_app_config is not None:
        executor_kwargs["app_config"] = resolved_app_config
    executor = SubagentExecutor(**executor_kwargs)

    # Start background execution (always async to prevent blocking)
    # Use tool_call_id as task_id for better traceability
    task_id = executor.execute_async(prompt, task_id=tool_call_id)

    # Poll for task completion in backend (removes need for LLM to poll)
    poll_count = 0
    last_status = None
    last_message_count = 0  # Track how many AI messages we've already sent
    # Polling timeout: execution timeout + 60s buffer, checked every 5s
    max_poll_count = (config.timeout_seconds + 60) // 5
    terminal_statuses = (
        SubagentStatus.COMPLETED,
        SubagentStatus.FAILED,
        SubagentStatus.CANCELLED,
        SubagentStatus.TIMED_OUT,
    )

    logger.info(f"[trace={trace_id}] Started background task {task_id} (subagent={subagent_type}, timeout={config.timeout_seconds}s, polling_limit={max_poll_count} polls)")

    writer = get_stream_writer()
    # Send Task Started message
    writer({"type": "task_started", "task_id": task_id, "description": description})

    try:
        while True:
            result = get_background_task_result(task_id)

            if result is None:
                logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
                writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
                cleanup_background_task(task_id)
                return f"Error: Task {task_id} disappeared from background tasks"

            # Read the status once per poll. The subagent runs in the background and
            # can change it between reads.
            status = result.status

            # Log status changes for debugging
            if status != last_status:
                logger.info(f"[trace={trace_id}] Task {task_id} status: {status.value}")
                last_status = status

            # Check for new AI messages and send task_running events
            ai_messages = result.ai_messages or []
            current_message_count = len(ai_messages)
            if current_message_count > last_message_count:
                # Send task_running event for each new message
                for i in range(last_message_count, current_message_count):
                    message = ai_messages[i]
                    writer(
                        {
                            "type": "task_running",
                            "task_id": task_id,
                            "message": message,
                            "message_index": i + 1,  # 1-based index for display
                            "total_messages": current_message_count,
                        }
                    )
                    logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
                last_message_count = current_message_count

            # Check if task completed, failed, was cancelled, or timed out
            if status in terminal_statuses:
                # Summarize usage only after a terminal status was observed, so the
                # snapshot includes the subagent's final records. The same snapshot is
                # cached for TokenUsageMiddleware writeback and attached to the event.
                usage = _summarize_usage(getattr(result, "token_usage_records", None))
                _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
                _report_subagent_usage(runtime, result)
                if status == SubagentStatus.COMPLETED:
                    writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
                    logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
                    reply = f"Task Succeeded. Result: {result.result}"
                elif status == SubagentStatus.FAILED:
                    writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
                    logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
                    reply = f"Task failed. Error: {result.error}"
                elif status == SubagentStatus.CANCELLED:
                    writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
                    logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
                    reply = "Task cancelled by user."
                else:  # SubagentStatus.TIMED_OUT (guaranteed by terminal_statuses)
                    writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
                    logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
                    reply = f"Task timed out. Error: {result.error}"
                cleanup_background_task(task_id)
                return reply

            # Still running, wait before next poll
            await asyncio.sleep(5)
            poll_count += 1

            # Polling timeout as a safety net (in case thread pool timeout doesn't work).
            # Set to execution timeout + 60s buffer, in 5s poll intervals; this catches
            # edge cases where the background task gets stuck. The task may still be
            # running, so it is not removed immediately. Instead, the same deferred
            # cleanup used on cancellation polls until it is terminal and removes it then.
            if poll_count > max_poll_count:
                timeout_minutes = config.timeout_seconds // 60
                logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
                _report_subagent_usage(runtime, result)
                usage = _summarize_usage(getattr(result, "token_usage_records", None))
                _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
                timeout_message = f"Task polling timed out after {timeout_minutes} minutes"
                writer({"type": "task_timed_out", "task_id": task_id, "error": timeout_message, "usage": usage})
                _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
                return f"{timeout_message}. This may indicate the background task is stuck. Status: {result.status.value}"
    except asyncio.CancelledError:
        # Signal the background subagent thread to stop cooperatively.
        request_cancel_background_task(task_id)

        # Wait (shielded) for the subagent to reach a terminal state so the
        # final token usage snapshot is reported to the parent RunJournal
        # before the parent worker persists get_completion_data().
        terminal_result = None
        try:
            terminal_result = await asyncio.shield(_await_subagent_terminal(task_id, max_poll_count))
        except asyncio.CancelledError:
            pass

        # Report whatever the subagent collected (even if we timed out).
        final_result = terminal_result or get_background_task_result(task_id)
        if final_result is not None:
            _report_subagent_usage(runtime, final_result)
        if final_result is not None and _is_subagent_terminal(final_result):
            cleanup_background_task(task_id)
        else:
            _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
        # The middleware writeback never runs after cancellation; drop the entry here.
        _subagent_usage_cache.pop(tool_call_id, None)
        raise
    except Exception:
        _subagent_usage_cache.pop(tool_call_id, None)
        raise
|