mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
- Replaced `fetchWithAuth` with a generic `fetch` function across various API modules for consistency. - Updated `useThreadStream` and `useThreadHistory` hooks to manage chat history loading, including loading states and pagination. - Introduced `LoadMoreHistoryIndicator` component for better user experience when loading more chat history. - Enhanced message handling in `MessageList` to accommodate new loading states and history management. - Added support for run messages in the thread context, improving the overall message handling logic. - Updated translations for loading indicators in English and Chinese.
375 lines
15 KiB
Python
375 lines
15 KiB
Python
"""Run event capture via LangChain callbacks.
|
|
|
|
RunJournal sits between LangChain's callback mechanism and the pluggable
|
|
RunEventStore. It standardizes callback data into RunEvent records and
|
|
handles token usage accumulation.
|
|
|
|
Key design decisions:
|
|
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
|
- on_chat_model_start, until a human message has been recorded, takes the most
  recent non-summary human message from the prompt, stores it as the run's first
  human message and emits it as an llm.human.input event. This is more reliable
  than on_chain_start (which fires on every node) because the messages here are
  fully structured.
|
|
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
|
- on_llm_end emits one llm.ai.response event per generated message (the LangChain
  message dump), with caller, usage, latency and LLM call-index metadata
|
|
- Token usage accumulated in memory, written to RunRow on run completion
|
|
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from datetime import UTC, datetime
|
|
from typing import TYPE_CHECKING, Any, cast
|
|
from uuid import UUID
|
|
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
|
from langgraph.types import Command
|
|
|
|
if TYPE_CHECKING:
|
|
from deerflow.runtime.events.store.base import RunEventStore
|
|
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class RunJournal(BaseCallbackHandler):
    """LangChain callback handler that captures run events into a RunEventStore.

    Callbacks translate LangChain events into plain dict records (thread_id,
    run_id, event_type, category, content, metadata, created_at). The records
    are buffered in memory and written with ``RunEventStore.put_batch``.
    LangChain invokes these handlers synchronously, so writes are scheduled as
    best-effort background tasks when an event loop is running. The async
    :meth:`flush` drains whatever is left.
    """

    def __init__(
        self,
        run_id: str,
        thread_id: str,
        event_store: RunEventStore,
        *,
        track_token_usage: bool = True,
        flush_threshold: int = 20,
    ):
        """Create a journal bound to a single run.

        Args:
            run_id: Run identifier stamped on every event.
            thread_id: Thread identifier stamped on every event.
            event_store: Store that receives event batches via ``put_batch``.
            track_token_usage: When False, token usage is not accumulated.
            flush_threshold: Number of buffered events that triggers a
                background flush; also the batch size used by :meth:`flush`.
        """
        super().__init__()
        self.run_id = run_id
        self.thread_id = thread_id
        self._store = event_store
        self._track_tokens = track_token_usage
        self._flush_threshold = flush_threshold

        # Write buffer and in-flight background flush tasks.
        self._buffer: list[dict] = []
        self._pending_flush_tasks: set[asyncio.Task[None]] = set()

        # Token accumulators.
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_tokens = 0
        self._llm_call_count = 0
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0

        # Convenience fields reported by get_completion_data().
        # NOTE(review): _last_ai_msg and _msg_count are never updated in this
        # class; confirm whether something else is expected to set them.
        self._last_ai_msg: str | None = None
        self._first_human_msg: str | None = None
        self._msg_count = 0

        # Latency tracking: langchain run_id -> monotonic start time.
        self._llm_start_times: dict[str, float] = {}

        # LLM request/response tracking.
        self._llm_call_index = 0
        # langchain run_id -> call index assigned when the call started, so a
        # response is labelled correctly even when LLM calls overlap.
        self._llm_call_indices: dict[str, int] = {}
        # langchain run_id -> prompt messages. Currently only a
        # "seen by on_chat_model_start" marker (values are empty lists).
        self._cached_prompts: dict[str, list[dict]] = {}

    # -- Lifecycle callbacks --

    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Emit a ``run.start`` trace event for the root chain invocation.

        Nested chains (graph nodes, inner runnables) have a parent run and are
        ignored.
        """
        if parent_run_id is not None:
            return
        chain_name = (serialized or {}).get("name", "unknown")
        self._put(
            event_type="run.start",
            category="trace",
            content={"chain": chain_name},
            metadata={"caller": self._identify_caller(tags), **(metadata or {})},
        )

    def on_chain_end(
        self,
        outputs: Any,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Emit ``run.end`` when the root chain finishes, then try to flush.

        Nested chains also report completion here. They are skipped so the run
        produces a single ``run.end`` matching the single ``run.start``.
        """
        if parent_run_id is not None:
            return
        self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
        self._flush_sync()

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Emit ``run.error`` when the root chain fails, then try to flush.

        An exception propagates through every enclosing chain. Without the
        root-only check, one failure would be recorded once per nesting level.
        """
        if parent_run_id is not None:
            return
        self._put(
            event_type="run.error",
            category="error",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )
        self._flush_sync()

    # -- LLM callbacks --

    def on_chat_model_start(
        self,
        serialized: dict,
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Start tracking a chat-model call and capture the human input.

        This is the canonical place to extract the first human message:
        messages are fully structured here, the callback fires only on real LLM
        calls, and the content is never compressed by checkpoint trimming.
        Until a human message has been recorded, the most recent non-summary
        HumanMessage in the prompt is stored via set_first_human_message and
        emitted as an ``llm.human.input`` event.
        """
        rid = str(run_id)
        self._llm_start_times[rid] = time.monotonic()
        self._llm_call_index += 1
        self._llm_call_indices[rid] = self._llm_call_index
        # Mark this run_id as seen so on_llm_end does not assign a new index.
        self._cached_prompts[rid] = []

        # Debug level with lazy args: prompts can be large and contain user data.
        logger.debug("on_chat_model_start %s: tags=%s serialized=%s messages=%s", run_id, tags, serialized, messages)

        if self._first_human_msg:
            return
        human = self._latest_human_message(messages)
        if human is None:
            return
        # NOTE(review): assumes a langchain-core version where `.text` is a
        # property; in older versions it is a method and must be called.
        self.set_first_human_message(human.text)
        self._put(
            event_type="llm.human.input",
            category="message",
            content=human.model_dump(),
            metadata={"caller": self._identify_caller(tags)},
        )

    def on_llm_start(
        self,
        serialized: dict,
        prompts: list[str],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the start time of a non-chat LLM call (latency only).

        on_chat_model_start is preferred; on_llm_end assigns the call index for
        calls that arrive through this path.
        """
        self._llm_start_times[str(run_id)] = time.monotonic()

    def on_llm_end(
        self,
        response: Any,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Emit one ``llm.ai.response`` event per generated message.

        Latency and the call index are resolved once per call and shared by
        every message of the response. When token tracking is enabled, usage
        is added to the run totals and attributed to the caller.
        """
        rid = str(run_id)
        caller = self._identify_caller(tags)
        logger.debug("on_llm_end %s: tags=%s kwargs=%s", run_id, tags, kwargs)

        messages: list[AnyMessage] = []
        for generation in response.generations:
            for gen in generation:
                if hasattr(gen, "message"):
                    messages.append(gen.message)
                else:
                    logger.warning("on_llm_end %s: generation has no message attribute: %s", run_id, gen)

        # Per-call bookkeeping, also cleans the tracking dicts for this run_id.
        start = self._llm_start_times.pop(rid, None)
        latency_ms = int((time.monotonic() - start) * 1000) if start is not None else None
        self._cached_prompts.pop(rid, None)
        call_index = self._llm_call_indices.pop(rid, None)
        if call_index is None:
            # Fallback: on_chat_model_start was not called for this run_id.
            self._llm_call_index += 1
            call_index = self._llm_call_index

        for message in messages:
            usage = getattr(message, "usage_metadata", None)
            usage_dict = dict(usage) if usage else {}
            self._put(
                event_type="llm.ai.response",
                category="message",
                content=message.model_dump(),
                metadata={
                    "caller": caller,
                    "usage": usage_dict,
                    "latency_ms": latency_ms,
                    "llm_call_index": call_index,
                },
            )
            if self._track_tokens:
                self._accumulate_tokens(caller, usage_dict)

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Drop tracking state for the failed call and emit an ``llm.error`` trace."""
        rid = str(run_id)
        self._llm_start_times.pop(rid, None)
        self._llm_call_indices.pop(rid, None)
        self._cached_prompts.pop(rid, None)
        self._put(
            event_type="llm.error",
            category="trace",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )

    def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
        """Log the tool start. No event is recorded; results are captured in on_tool_end."""
        logger.debug("Tool start for run %s: tags=%s metadata=%s", run_id, tags, metadata)

    def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
        """Record tool result messages as ``llm.tool.result`` events.

        Handles a ToolMessage output directly. For a langgraph Command output,
        handles the messages in ``update["messages"]`` when ``update`` is a
        dict; other update shapes (including None) carry no messages here.
        """
        try:
            if isinstance(output, ToolMessage):
                msg = cast(ToolMessage, output)
                self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
            elif isinstance(output, Command):
                cmd = cast(Command, output)
                update = cmd.update
                messages = (update.get("messages") or []) if isinstance(update, dict) else []
                if isinstance(messages, BaseMessage):
                    messages = [messages]
                for message in messages:
                    if isinstance(message, BaseMessage):
                        self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
                    else:
                        logger.warning("on_tool_end %s: command update message is not BaseMessage: %s", run_id, type(message))
            else:
                logger.warning("on_tool_end %s: output is not ToolMessage: %s", run_id, type(output))
        finally:
            logger.debug("Tool end for run %s", run_id)

    # -- Internal methods --

    def _put(self, *, event_type: str, category: str, content: Any = "", metadata: dict | None = None) -> None:
        """Append one event record to the buffer; flush once the threshold is reached."""
        self._buffer.append(
            {
                "thread_id": self.thread_id,
                "run_id": self.run_id,
                "event_type": event_type,
                "category": category,
                "content": content,
                "metadata": metadata or {},
                "created_at": datetime.now(UTC).isoformat(),
            }
        )
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _accumulate_tokens(self, caller: str, usage: dict) -> None:
        """Add one message's token usage to the run totals and the caller bucket.

        Messages that report no tokens (and whose input + output also sum to
        zero) are ignored and do not count as an LLM call.
        """
        input_tk = usage.get("input_tokens", 0) or 0
        output_tk = usage.get("output_tokens", 0) or 0
        total_tk = usage.get("total_tokens", 0) or 0
        if total_tk == 0:
            total_tk = input_tk + output_tk
        if total_tk <= 0:
            return
        self._total_input_tokens += input_tk
        self._total_output_tokens += output_tk
        self._total_tokens += total_tk
        self._llm_call_count += 1
        # Caller tags come from _identify_caller: "subagent:*", "middleware:*",
        # or "lead_agent" (the default).
        if caller.startswith("subagent:"):
            self._subagent_tokens += total_tk
        elif caller.startswith("middleware:"):
            self._middleware_tokens += total_tk
        else:
            self._lead_agent_tokens += total_tk

    def _flush_sync(self) -> None:
        """Best-effort flush of the buffer to the RunEventStore.

        BaseCallbackHandler methods are synchronous. If an event loop is
        running, an async ``put_batch`` is scheduled. Otherwise the events
        stay in the buffer until the async ``flush()`` call in the worker's
        ``finally`` block.
        """
        if not self._buffer:
            return
        # Skip if a flush is already in flight. This avoids concurrent writes
        # to the same SQLite file from multiple fire-and-forget tasks.
        if self._pending_flush_tasks:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop: keep events in the buffer for the later async flush.
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        task = loop.create_task(self._flush_async(batch))
        self._pending_flush_tasks.add(task)
        task.add_done_callback(self._on_flush_done)

    async def _flush_async(self, batch: list[dict]) -> None:
        """Write one batch; on failure, put it back at the front of the buffer."""
        try:
            await self._store.put_batch(batch)
        except Exception:
            logger.warning(
                "Failed to flush %d events for run %s — returning to buffer",
                len(batch),
                self.run_id,
                exc_info=True,
            )
            # Return failed events to the buffer for retry on the next flush.
            self._buffer = batch + self._buffer

    def _on_flush_done(self, task: asyncio.Task) -> None:
        """Forget a finished flush task and log any exception it escaped with."""
        self._pending_flush_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc:
            logger.warning("Journal flush task failed: %s", exc)

    def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
        """Return the first caller tag found in *tags*, defaulting to "lead_agent".

        Recognized tags are "lead_agent", "subagent:{name}" and
        "middleware:{name}". Extra keyword arguments are accepted and ignored.
        """
        for tag in tags or ():
            if isinstance(tag, str) and (tag == "lead_agent" or tag.startswith(("subagent:", "middleware:"))):
                return tag
        # Default to lead_agent: the main agent graph does not inject callback
        # tags, while subagents and middleware explicitly tag themselves.
        return "lead_agent"

    @staticmethod
    def _latest_human_message(messages: list[list[BaseMessage]]) -> HumanMessage | None:
        """Return the last non-summary HumanMessage across all batches, or None."""
        for batch in reversed(messages):
            for m in reversed(batch):
                if isinstance(m, HumanMessage) and m.name != "summary":
                    return m
        return None

    # -- Public methods (called by worker) --

    def set_first_human_message(self, content: str) -> None:
        """Record the first human message (truncated to 2000 chars) for convenience fields."""
        self._first_human_msg = content[:2000] if content else None

    def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
        """Record a middleware state-change event.

        Called by middleware implementations when they perform a meaningful
        state change (e.g., title generation, summarization, HITL approval).
        Pure-observation middleware should not call this.

        Args:
            tag: Short identifier for the middleware (e.g., "title", "summarize",
                "guardrail"). Used to form event_type="middleware:{tag}".
            name: Full middleware class name.
            hook: Lifecycle hook that triggered the action (e.g., "after_model").
            action: Specific action performed (e.g., "generate_title").
            changes: Dict describing the state changes made.
        """
        self._put(
            event_type=f"middleware:{tag}",
            category="middleware",
            content={"name": name, "hook": hook, "action": action, "changes": changes},
        )

    async def flush(self) -> None:
        """Force-flush the remaining buffer. Called in the worker's finally block.

        Waits for in-flight background flushes first (a failed one puts its
        events back into the buffer). Then writes the buffer in batches of at
        most ``flush_threshold`` events.

        Raises:
            Exception: Whatever ``put_batch`` raises. The failed batch is put
                back at the front of the buffer before re-raising.
        """
        if self._pending_flush_tasks:
            await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)

        # Guard against flush_threshold <= 0: slicing with a zero or negative
        # bound would never drain the buffer and the loop would spin forever.
        batch_size = max(self._flush_threshold, 1)
        while self._buffer:
            batch = self._buffer[:batch_size]
            del self._buffer[:batch_size]
            try:
                await self._store.put_batch(batch)
            except Exception:
                self._buffer = batch + self._buffer
                raise

    def get_completion_data(self) -> dict:
        """Return accumulated token and message data for run completion."""
        return {
            "total_input_tokens": self._total_input_tokens,
            "total_output_tokens": self._total_output_tokens,
            "total_tokens": self._total_tokens,
            "llm_call_count": self._llm_call_count,
            "lead_agent_tokens": self._lead_agent_tokens,
            "subagent_tokens": self._subagent_tokens,
            "middleware_tokens": self._middleware_tokens,
            "message_count": self._msg_count,
            "last_ai_message": self._last_ai_msg,
            "first_human_message": self._first_human_msg,
        }
|