From df63c104a7a1e7db23e566739e417d1d1f1ddbed Mon Sep 17 00:00:00 2001 From: JeffJiang Date: Fri, 17 Apr 2026 23:41:11 +0800 Subject: [PATCH] Refactor API fetch calls to use a unified fetch function; enhance chat history loading with new hooks and UI components - Replaced `fetchWithAuth` with a generic `fetch` function across various API modules for consistency. - Updated `useThreadStream` and `useThreadHistory` hooks to manage chat history loading, including loading states and pagination. - Introduced `LoadMoreHistoryIndicator` component for better user experience when loading more chat history. - Enhanced message handling in `MessageList` to accommodate new loading states and history management. - Added support for run messages in the thread context, improving the overall message handling logic. - Updated translations for loading indicators in English and Chinese. --- backend/app/gateway/deps.py | 37 +- backend/app/gateway/routers/runs.py | 3 +- backend/app/gateway/routers/thread_runs.py | 18 +- backend/app/gateway/routers/uploads.py | 2 +- backend/app/gateway/services.py | 16 - .../deerflow/agents/lead_agent/agent.py | 3 +- .../middlewares/loop_detection_middleware.py | 2 +- .../middlewares/summarization_middleware.py | 13 + .../middlewares/thread_data_middleware.py | 16 +- .../agents/middlewares/uploads_middleware.py | 1 + .../harness/deerflow/runtime/journal.py | 411 ++++++------------ .../harness/deerflow/runtime/runs/manager.py | 11 +- .../deerflow/runtime/runs/store/base.py | 1 - .../deerflow/runtime/runs/store/memory.py | 2 - .../harness/deerflow/runtime/runs/worker.py | 21 +- .../[agent_name]/chats/[thread_id]/page.tsx | 11 +- .../app/workspace/chats/[thread_id]/page.tsx | 26 +- .../src/components/workspace/input-box.tsx | 38 +- .../workspace/messages/message-list.tsx | 128 +++++- .../settings/account-settings-page.tsx | 4 +- frontend/src/core/agents/api.ts | 8 +- frontend/src/core/api/feedback.ts | 6 +- frontend/src/core/api/fetcher.ts | 4 +- frontend/src/core/i18n/locales/en-US.ts | 1 + frontend/src/core/i18n/locales/types.ts | 1 + frontend/src/core/i18n/locales/zh-CN.ts | 1 + frontend/src/core/mcp/api.ts | 17 +- frontend/src/core/memory/api.ts | 38 +- frontend/src/core/messages/utils.ts | 6 +- frontend/src/core/skills/api.ts | 19 +- frontend/src/core/threads/hooks.ts | 403 +++++++++-------- frontend/src/core/threads/types.ts | 9 + frontend/src/core/uploads/api.ts | 6 +- 33 files changed, 665 insertions(+), 618 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index f4fdad473..20da78af9 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -8,13 +8,17 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. 
from __future__ import annotations -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, cast from fastapi import FastAPI, HTTPException, Request +from langgraph.types import Checkpointer -from deerflow.runtime import RunContext, RunManager +from deerflow.persistence.feedback import FeedbackRepository +from deerflow.runtime import RunContext, RunManager, StreamBridge +from deerflow.runtime.events.store.base import RunEventStore +from deerflow.runtime.runs.store.base import RunStore if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider @@ -22,6 +26,9 @@ if TYPE_CHECKING: from deerflow.persistence.thread_meta.base import ThreadMetaStore +T = TypeVar("T") + + @asynccontextmanager async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: """Bootstrap and tear down all LangGraph runtime singletons. @@ -84,25 +91,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: # --------------------------------------------------------------------------- -def _require(attr: str, label: str): +def _require(attr: str, label: str) -> Callable[[Request], T]: """Create a FastAPI dependency that returns ``app.state.`` or 503.""" - def dep(request: Request): + def dep(request: Request) -> T: val = getattr(request.app.state, attr, None) if val is None: raise HTTPException(status_code=503, detail=f"{label} not available") - return val + return cast(T, val) dep.__name__ = dep.__qualname__ = f"get_{attr}" return dep -get_stream_bridge = _require("stream_bridge", "Stream bridge") -get_run_manager = _require("run_manager", "Run manager") -get_checkpointer = _require("checkpointer", "Checkpointer") -get_run_event_store = _require("run_event_store", "Run event store") -get_feedback_repo = _require("feedback_repo", "Feedback") -get_run_store = _require("run_store", "Run store") +get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge") +get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager") +get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer") +get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store") +get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback") +get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store") def get_store(request: Request): @@ -121,10 +128,7 @@ def get_thread_store(request: Request) -> ThreadMetaStore: def get_run_context(request: Request) -> RunContext: """Build a :class:`RunContext` from ``app.state`` singletons. - Returns a *base* context with infrastructure dependencies. Callers that - need per-run fields (e.g. ``follow_up_to_run_id``) should use - ``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it - to :func:`run_agent`. + Returns a *base* context with infrastructure dependencies. 
""" from deerflow.config import get_app_config @@ -137,7 +141,6 @@ def get_run_context(request: Request) -> RunContext: ) - # --------------------------------------------------------------------------- # Auth helpers (used by authz.py and auth middleware) # --------------------------------------------------------------------------- diff --git a/backend/app/gateway/routers/runs.py b/backend/app/gateway/routers/runs.py index 70e2abb63..f2775466c 100644 --- a/backend/app/gateway/routers/runs.py +++ b/backend/app/gateway/routers/runs.py @@ -123,7 +123,8 @@ async def run_messages( run = await _resolve_run(run_id, request) event_store = get_run_event_store(request) rows = await event_store.list_messages_by_run( - run["thread_id"], run_id, + run["thread_id"], + run_id, limit=limit + 1, before_seq=before_seq, after_seq=after_seq, diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index e21375ab9..e6847c50f 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -54,7 +54,6 @@ class RunCreateRequest(BaseModel): after_seconds: float | None = Field(default=None, description="Delayed execution") if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy") feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys") - follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.") class RunResponse(BaseModel): @@ -312,11 +311,15 @@ async def list_thread_messages( if i in last_ai_indices: run_id = msg["run_id"] fb = feedback_map.get(run_id) - msg["feedback"] = { - "feedback_id": fb["feedback_id"], - "rating": fb["rating"], - "comment": fb.get("comment"), - } if fb else None + msg["feedback"] = ( + { + "feedback_id": fb["feedback_id"], + "rating": fb["rating"], + "comment": fb.get("comment"), + } + if fb + else None + ) else: msg["feedback"] = None @@ -339,7 +342,8 @@ async def list_run_messages( """ event_store = get_run_event_store(request) rows = await event_store.list_messages_by_run( - thread_id, run_id, + thread_id, + run_id, limit=limit + 1, before_seq=before_seq, after_seq=after_seq, diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py index aa707e9ea..e31ff11d2 100644 --- a/backend/app/gateway/routers/uploads.py +++ b/backend/app/gateway/routers/uploads.py @@ -56,7 +56,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None: @router.post("", response_model=UploadResponse) -@require_permission("threads", "write", owner_check=True, require_existing=True) +@require_permission("threads", "write", owner_check=True, require_existing=False) async def upload_files( thread_id: str, request: Request, diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 72f074907..9a1cdb12f 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -195,21 +195,6 @@ async def start_run( disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ - # Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run - follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None) - if follow_up_to_run_id is None: - run_store = get_run_store(request) - try: - recent_runs = await run_store.list_by_thread(thread_id, limit=1) - if recent_runs and 
recent_runs[0].get("status") == "success":
-                follow_up_to_run_id = recent_runs[0]["run_id"]
-        except Exception:
-            pass  # Don't block run creation
-
-    # Enrich base context with per-run field
-    if follow_up_to_run_id:
-        run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
-
     try:
         record = await run_mgr.create_or_reject(
             thread_id,
@@ -218,7 +203,6 @@ async def start_run(
             metadata=body.metadata or {},
             kwargs={"input": body.input, "config": body.config},
             multitask_strategy=body.multitask_strategy,
-            follow_up_to_run_id=follow_up_to_run_id,
         )
     except ConflictError as exc:
         raise HTTPException(status_code=409, detail=str(exc)) from exc
diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py
index 352ff021e..ef8224e6b 100644
--- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py
+++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py
@@ -1,7 +1,7 @@
 import logging
 
 from langchain.agents import create_agent
-from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
+from langchain.agents.middleware import AgentMiddleware
 from langchain_core.runnables import RunnableConfig
 
 from deerflow.agents.lead_agent.prompt import apply_prompt_template
@@ -9,6 +9,7 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi
 from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
 from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
 from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
+from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
 from deerflow.agents.middlewares.title_middleware import TitleMiddleware
 from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
 from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py
index 0b161152c..1fdc01fcc 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py
@@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
             # the conversation; injecting one mid-conversation crashes
             # langchain_anthropic's _format_messages(). HumanMessage works
             # with all providers. See #1299.
-            return {"messages": [HumanMessage(content=warning)]}
+            return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
 
         return None
 
diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
new file mode 100644
index 000000000..243cdb39f
--- /dev/null
+++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
@@ -0,0 +1,13 @@
+from typing import override
+
+from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
+from langchain_core.messages.human import HumanMessage
+
+
+class SummarizationMiddleware(BaseSummarizationMiddleware):
+    @override
+    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
+        """Override the base implementation to tag the summary HumanMessage with the special name 'summary'.
+        Messages with this name are hidden in the frontend but are still passed to the model as context.
+        """
+        return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
diff --git a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py
index 828a82621..8d93de4ff 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py
@@ -1,8 +1,10 @@
 import logging
+from datetime import UTC, datetime
 from typing import NotRequired, override
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
+from langchain_core.messages import HumanMessage
 from langgraph.config import get_config
 from langgraph.runtime import Runtime
 
@@ -97,8 +99,20 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
         paths = self._create_thread_directories(thread_id, user_id=user_id)
         logger.debug("Created thread data directories for thread %s", thread_id)
 
+        messages = list(state.get("messages", []))
+        last_message = messages[-1] if messages else None
+
+        if last_message and isinstance(last_message, HumanMessage):
+            messages[-1] = HumanMessage(
+                content=last_message.content,
+                id=last_message.id,
+                name=last_message.name or "user-input",
+                additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
+            )
+
         return {
             "thread_data": {
                 **paths,
-            }
+            },
+            "messages": messages,
         }
diff --git a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py
index 6622fb695..4f584c5c3 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py
@@ -279,6 +279,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
         updated_message = HumanMessage(
             content=f"{files_message}\n\n{original_content}",
             id=last_message.id,
+            name=last_message.name,
             additional_kwargs=last_message.additional_kwargs,
         )
 
diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py
index a70404e11..5f1838888 100644
--- a/backend/packages/harness/deerflow/runtime/journal.py
+++ b/backend/packages/harness/deerflow/runtime/journal.py
@@ -6,7 +6,10 @@ handles token usage accumulation.
 
 Key design decisions:
 - on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
-- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
+- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
+  extracts the first human message for run.input; it is more reliable than
+  on_chain_start (which fires on every node) and the messages here are fully structured.
+- on_chain_start with parent_run_id=None emits a run.start trace marking the root invocation.
- on_llm_end emits llm_response in OpenAI Chat Completions format - Token usage accumulated in memory, written to RunRow on run completion - Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name}) @@ -18,10 +21,12 @@ import asyncio import logging import time from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage +from langgraph.types import Command if TYPE_CHECKING: from deerflow.runtime.events.store.base import RunEventStore @@ -72,34 +77,39 @@ class RunJournal(BaseCallbackHandler): # LLM request/response tracking self._llm_call_index = 0 self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages - self._cached_models: dict[str, str] = {} # langchain run_id -> model name - - # Tool call ID cache - self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id # -- Lifecycle callbacks -- - def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None: - if kwargs.get("parent_run_id") is not None: - return - self._put( - event_type="run_start", - category="lifecycle", - metadata={"input_preview": str(inputs)[:500]}, - ) + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + caller = self._identify_caller(tags) + if parent_run_id is None: + # Root graph invocation — emit a single trace event for the run start. + chain_name = (serialized or {}).get("name", "unknown") + self._put( + event_type="run.start", + category="trace", + content={"chain": chain_name}, + metadata={"caller": caller, **(metadata or {})}, + ) def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None: - if kwargs.get("parent_run_id") is not None: - return - self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"}) + self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"}) self._flush_sync() def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: - if kwargs.get("parent_run_id") is not None: - return self._put( - event_type="run_error", - category="lifecycle", + event_type="run.error", + category="error", content=str(error), metadata={"error_type": type(error).__name__}, ) @@ -107,266 +117,132 @@ class RunJournal(BaseCallbackHandler): # -- LLM callbacks -- - def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None: - """Capture structured prompt messages for llm_request event.""" - from deerflow.runtime.converters import langchain_messages_to_openai + def on_chat_model_start( + self, + serialized: dict, + messages: list[list[BaseMessage]], + *, + run_id: UUID, + tags: list[str] | None = None, + **kwargs: Any, + ) -> None: + """Capture structured prompt messages for llm_request event. + This is also the canonical place to extract the first human message: + messages are fully structured here, it fires only on real LLM calls, + and the content is never compressed by checkpoint trimming. 
+        """
         rid = str(run_id)
         self._llm_start_times[rid] = time.monotonic()
         self._llm_call_index += 1
+        # Mark this run_id as seen so on_llm_end knows not to increment again.
+        self._cached_prompts[rid] = []
 
-        model_name = serialized.get("name", "")
-        self._cached_models[rid] = model_name
+        logger.debug(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
 
-        # Convert the first message list (LangChain passes list-of-lists)
-        prompt_msgs = messages[0] if messages else []
-        openai_msgs = langchain_messages_to_openai(prompt_msgs)
-        self._cached_prompts[rid] = openai_msgs
+        # Capture the user's input for this run: the most recent non-summary human message in the first LLM prompt.
+        if not self._first_human_msg:
+            for batch in reversed(messages):
+                for m in reversed(batch):
+                    if isinstance(m, HumanMessage) and m.name != "summary":
+                        caller = self._identify_caller(tags)
+                        self.set_first_human_message(m.text)
+                        self._put(
+                            event_type="llm.human.input",
+                            category="message",
+                            content=m.model_dump(),
+                            metadata={"caller": caller},
+                        )
+                        break
+                if self._first_human_msg:
+                    break
 
-        caller = self._identify_caller(kwargs)
-        self._put(
-            event_type="llm_request",
-            category="trace",
-            content={"model": model_name, "messages": openai_msgs},
-            metadata={"caller": caller, "llm_call_index": self._llm_call_index},
-        )
-
-    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
+    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
         # Fallback: on_chat_model_start is preferred. This just tracks latency.
         self._llm_start_times[str(run_id)] = time.monotonic()
 
-    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
-        from deerflow.runtime.converters import langchain_to_openai_completion
-
-        try:
-            message = response.generations[0][0].message
-        except (IndexError, AttributeError):
-            logger.debug("on_llm_end: could not extract message from response")
-            return
-
-        caller = self._identify_caller(kwargs)
-
-        # Latency
-        rid = str(run_id)
-        start = self._llm_start_times.pop(rid, None)
-        latency_ms = int((time.monotonic() - start) * 1000) if start else None
-
-        # Token usage from message
-        usage = getattr(message, "usage_metadata", None)
-        usage_dict = dict(usage) if usage else {}
-
-        # Resolve call index
-        call_index = self._llm_call_index
-        if rid not in self._cached_prompts:
-            # Fallback: on_chat_model_start was not called
-            self._llm_call_index += 1
-            call_index = self._llm_call_index
-
-        # Clean up caches
-        self._cached_prompts.pop(rid, None)
-        self._cached_models.pop(rid, None)
-
-        # Trace event: llm_response (OpenAI completion format)
-        content = getattr(message, "content", "")
-        self._put(
-            event_type="llm_response",
-            category="trace",
-            content=langchain_to_openai_completion(message),
-            metadata={
-                "caller": caller,
-                "usage": usage_dict,
-                "latency_ms": latency_ms,
-                "llm_call_index": call_index,
-            },
-        )
-
-        # Message events: only lead_agent gets message-category events.
-        # Content uses message.model_dump() to align with checkpoint format.
-        tool_calls = getattr(message, "tool_calls", None) or []
-        if caller == "lead_agent":
-            resp_meta = getattr(message, "response_metadata", None) or {}
-            model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
-            if tool_calls:
-                # ai_tool_call: agent decided to use tools
-                self._put(
-                    event_type="ai_tool_call",
-                    category="message",
-                    content=message.model_dump(),
-                    metadata={"model_name": model_name, "finish_reason": "tool_calls"},
-                )
-            elif isinstance(content, str) and content:
-                # ai_message: final text reply
-                self._put(
-                    event_type="ai_message",
-                    category="message",
-                    content=message.model_dump(),
-                    metadata={"model_name": model_name, "finish_reason": "stop"},
-                )
-                self._last_ai_msg = content
-                self._msg_count += 1
-
-        # Token accumulation
-        if self._track_tokens:
-            input_tk = usage_dict.get("input_tokens", 0) or 0
-            output_tk = usage_dict.get("output_tokens", 0) or 0
-            total_tk = usage_dict.get("total_tokens", 0) or 0
-            if total_tk == 0:
-                total_tk = input_tk + output_tk
-            if total_tk > 0:
-                self._total_input_tokens += input_tk
-                self._total_output_tokens += output_tk
-                self._total_tokens += total_tk
-                self._llm_call_count += 1
-                if caller.startswith("subagent:"):
-                    self._subagent_tokens += total_tk
-                elif caller.startswith("middleware:"):
-                    self._middleware_tokens += total_tk
+    def on_llm_end(self, response: Any, *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, **kwargs: Any) -> None:
+        messages: list[AnyMessage] = []
+        logger.debug(f"on_llm_end {run_id}: response: {tags} {kwargs}")
+        for generation in response.generations:
+            for gen in generation:
+                if hasattr(gen, "message"):
+                    messages.append(gen.message)
                 else:
-                    self._lead_agent_tokens += total_tk
+                    logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
+
+        for message in messages:
+            caller = self._identify_caller(tags)
+
+            # Latency
+            rid = str(run_id)
+            start = self._llm_start_times.pop(rid, None)
+            latency_ms = int((time.monotonic() - start) * 1000) if start else None
+
+            # Token usage from message
+            usage = getattr(message, "usage_metadata", None)
+            usage_dict = dict(usage) if usage else {}
+
+            # Resolve call index
+            call_index = self._llm_call_index
+            if rid not in self._cached_prompts:
+                # Fallback: on_chat_model_start was not called
+                self._llm_call_index += 1
+                call_index = self._llm_call_index
+
+            # Message event: llm.ai.response (checkpoint-aligned model_dump format)
+            self._put(
+                event_type="llm.ai.response",
+                category="message",
+                content=message.model_dump(),
+                metadata={
+                    "caller": caller,
+                    "usage": usage_dict,
+                    "latency_ms": latency_ms,
+                    "llm_call_index": call_index,
+                },
+            )
+
+            # Token accumulation
+            if self._track_tokens:
+                input_tk = usage_dict.get("input_tokens", 0) or 0
+                output_tk = usage_dict.get("output_tokens", 0) or 0
+                total_tk = usage_dict.get("total_tokens", 0) or 0
+                if total_tk == 0:
+                    total_tk = input_tk + output_tk
+                if total_tk > 0:
+                    self._total_input_tokens += input_tk
+                    self._total_output_tokens += output_tk
+                    self._total_tokens += total_tk
+                    self._llm_call_count += 1
 
     def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
         self._llm_start_times.pop(str(run_id), None)
-        self._put(event_type="llm_error", category="trace", content=str(error))
+        self._put(event_type="llm.error", category="trace", content=str(error))
 
-    # -- Tool callbacks --
+    def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
+        """Handle tool start event; derive a tool_call_id from the callback run_id and log the invocation."""
+        
tool_call_id = str(run_id) + logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}") - def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None: - tool_call_id = kwargs.get("tool_call_id") - if tool_call_id: - self._tool_call_ids[str(run_id)] = tool_call_id - self._put( - event_type="tool_start", - category="trace", - metadata={ - "tool_name": serialized.get("name", ""), - "tool_call_id": tool_call_id, - "args": str(input_str)[:2000], - }, - ) - - def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: - from langchain_core.messages import ToolMessage - from langgraph.types import Command - - # Tools that update graph state return a ``Command`` (e.g. - # ``present_files``). LangGraph later unwraps the inner ToolMessage - # into checkpoint state, so to stay checkpoint-aligned we must - # extract it here rather than storing ``str(Command(...))``. - if isinstance(output, Command): - update = getattr(output, "update", None) or {} - inner_msgs = update.get("messages") if isinstance(update, dict) else None - if isinstance(inner_msgs, list): - inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None) - if inner_tool_msg is not None: - output = inner_tool_msg - - # Extract fields from ToolMessage object when LangChain provides one. - # LangChain's _format_output wraps tool results into a ToolMessage - # with tool_call_id, name, status, and artifact — more complete than - # what kwargs alone provides. - if isinstance(output, ToolMessage): - tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) - tool_name = output.name or kwargs.get("name", "") - status = getattr(output, "status", "success") or "success" - content_str = output.content if isinstance(output.content, str) else str(output.content) - # Use model_dump() for checkpoint-aligned message content. - # Override tool_call_id if it was resolved from cache. - msg_content = output.model_dump() - if msg_content.get("tool_call_id") != tool_call_id: - msg_content["tool_call_id"] = tool_call_id - else: - tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) - tool_name = kwargs.get("name", "") - status = "success" - content_str = str(output) - # Construct checkpoint-aligned dict when output is a plain string. 
- msg_content = ToolMessage( - content=content_str, - tool_call_id=tool_call_id or "", - name=tool_name, - status=status, - ).model_dump() - - # Trace event (always) - self._put( - event_type="tool_end", - category="trace", - content=content_str, - metadata={ - "tool_name": tool_name, - "tool_call_id": tool_call_id, - "status": status, - }, - ) - - # Message event: tool_result (checkpoint-aligned model_dump format) - self._put( - event_type="tool_result", - category="message", - content=msg_content, - metadata={"tool_name": tool_name, "status": status}, - ) - - def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: - from langchain_core.messages import ToolMessage - - tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) - tool_name = kwargs.get("name", "") - - # Trace event - self._put( - event_type="tool_error", - category="trace", - content=str(error), - metadata={ - "tool_name": tool_name, - "tool_call_id": tool_call_id, - }, - ) - - # Message event: tool_result with error status (checkpoint-aligned) - msg_content = ToolMessage( - content=str(error), - tool_call_id=tool_call_id or "", - name=tool_name, - status="error", - ).model_dump() - self._put( - event_type="tool_result", - category="message", - content=msg_content, - metadata={"tool_name": tool_name, "status": "error"}, - ) - - # -- Custom event callback -- - - def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None: - from deerflow.runtime.serialization import serialize_lc_object - - if name == "summarization": - data_dict = data if isinstance(data, dict) else {} - self._put( - event_type="summarization", - category="trace", - content=data_dict.get("summary", ""), - metadata={ - "replaced_message_ids": data_dict.get("replaced_message_ids", []), - "replaced_count": data_dict.get("replaced_count", 0), - }, - ) - self._put( - event_type="middleware:summarize", - category="middleware", - content={"role": "system", "content": data_dict.get("summary", "")}, - metadata={"replaced_count": data_dict.get("replaced_count", 0)}, - ) - else: - event_data = serialize_lc_object(data) if not isinstance(data, dict) else data - self._put( - event_type=name, - category="trace", - metadata=event_data if isinstance(event_data, dict) else {"data": event_data}, - ) + def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs): + """Handle tool end event, append message and clear node data""" + try: + if isinstance(output, ToolMessage): + msg = cast(ToolMessage, output) + self._put(event_type="llm.tool.result", category="message", content=msg.model_dump()) + elif isinstance(output, Command): + cmd = cast(Command, output) + messages = cmd.update.get("messages", []) + for message in messages: + if isinstance(message, BaseMessage): + self._put(event_type="llm.tool.result", category="message", content=message.model_dump()) + else: + logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}") + else: + logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}") + finally: + logger.info(f"Tool end for node {run_id}") # -- Internal methods -- @@ -431,8 +307,9 @@ class RunJournal(BaseCallbackHandler): if exc: logger.warning("Journal flush task failed: %s", exc) - def _identify_caller(self, kwargs: dict) -> str: - for tag in kwargs.get("tags") or []: + def _identify_caller(self, tags: list[str] | None, **kwargs) -> str: + _tags = tags or kwargs.get("tags", []) + for tag in _tags: if 
isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
                 return tag
         # Default to lead_agent: the main agent graph does not inject
diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py
index 0a0794d87..a54a408b8 100644
--- a/backend/packages/harness/deerflow/runtime/runs/manager.py
+++ b/backend/packages/harness/deerflow/runtime/runs/manager.py
@@ -54,7 +54,7 @@ class RunManager:
         self._lock = asyncio.Lock()
         self._store = store
 
-    async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
+    async def _persist_to_store(self, record: RunRecord) -> None:
         """Best-effort persist run record to backing store."""
         if self._store is None:
             return
@@ -68,7 +68,6 @@ class RunManager:
                 metadata=record.metadata or {},
                 kwargs=record.kwargs or {},
                 created_at=record.created_at,
-                follow_up_to_run_id=follow_up_to_run_id,
             )
         except Exception:
             logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -90,7 +89,6 @@ class RunManager:
         metadata: dict | None = None,
         kwargs: dict | None = None,
         multitask_strategy: str = "reject",
-        follow_up_to_run_id: str | None = None,
     ) -> RunRecord:
         """Create a new pending run and register it."""
         run_id = str(uuid.uuid4())
@@ -109,7 +107,7 @@ class RunManager:
         )
         async with self._lock:
             self._runs[run_id] = record
-        await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
+        await self._persist_to_store(record)
         logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
         return record
 
@@ -122,7 +120,7 @@ class RunManager:
         async with self._lock:
-            # Dict insertion order matches creation order, so reversing it gives
-            # us deterministic newest-first results even when timestamps tie.
-            return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
+            # Dict insertion order matches creation order, so iterating it directly
+            # gives deterministic oldest-first results even when timestamps tie.
+            return [r for r in self._runs.values() if r.thread_id == thread_id]
 
     async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
         """Transition a run to a new status."""
@@ -176,7 +174,6 @@ class RunManager:
         metadata: dict | None = None,
         kwargs: dict | None = None,
         multitask_strategy: str = "reject",
-        follow_up_to_run_id: str | None = None,
     ) -> RunRecord:
         """Atomically check for inflight runs and create a new one.
@@ -230,7 +227,7 @@ class RunManager: ) self._runs[run_id] = record - await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) + await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 3212e8ca3..518a1903c 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -29,7 +29,6 @@ class RunStore(abc.ABC): kwargs: dict[str, Any] | None = None, error: str | None = None, created_at: str | None = None, - follow_up_to_run_id: str | None = None, ) -> None: pass diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 0b2b05f07..5a14af3df 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -28,7 +28,6 @@ class MemoryRunStore(RunStore): kwargs=None, error=None, created_at=None, - follow_up_to_run_id=None, ): now = datetime.now(UTC).isoformat() self._runs[run_id] = { @@ -41,7 +40,6 @@ class MemoryRunStore(RunStore): "metadata": metadata or {}, "kwargs": kwargs or {}, "error": error, - "follow_up_to_run_id": follow_up_to_run_id, "created_at": created_at or now, "updated_at": now, } diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 506ab00cd..c018bcabd 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -51,7 +51,6 @@ class RunContext: event_store: Any | None = field(default=None) run_events_config: Any | None = field(default=None) thread_store: Any | None = field(default=None) - follow_up_to_run_id: str | None = field(default=None) async def run_agent( @@ -76,7 +75,6 @@ async def run_agent( event_store = ctx.event_store run_events_config = ctx.run_events_config thread_store = ctx.thread_store - follow_up_to_run_id = ctx.follow_up_to_run_id run_id = record.run_id thread_id = record.thread_id @@ -113,22 +111,6 @@ async def run_agent( track_token_usage=getattr(run_events_config, "track_token_usage", True), ) - human_msg = _extract_human_message(graph_input) - if human_msg is not None: - msg_metadata = {} - if follow_up_to_run_id: - msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id - await event_store.put( - thread_id=thread_id, - run_id=run_id, - event_type="human_message", - category="message", - content=human_msg.model_dump(), - metadata=msg_metadata or None, - ) - content = human_msg.content - journal.set_first_human_message(content if isinstance(content, str) else str(content)) - # 1. Mark running await run_manager.set_status(run_id, RunStatus.running) @@ -166,12 +148,13 @@ async def run_agent( # Inject runtime context so middlewares can access thread_id # (langgraph-cli does this automatically; we must do it manually) - runtime = Runtime(context={"thread_id": thread_id}, store=store) + runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store) # If the caller already set a ``context`` key (LangGraph >= 0.6.0 # prefers it over ``configurable`` for thread-level data), make # sure ``thread_id`` is available there too. 
if "context" in config and isinstance(config["context"], dict): config["context"].setdefault("thread_id", thread_id) + config["context"].setdefault("run_id", run_id) config.setdefault("configurable", {})["__pregel_runtime"] = runtime # Inject RunJournal as a LangChain callback handler. diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index 7b288a40d..cf3a3e381 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -46,7 +46,13 @@ export default function AgentChatPage() { const [settings, setSettings] = useThreadSettings(threadId); const { showNotification } = useNotification(); - const [thread, sendMessage] = useThreadStream({ + const { + thread, + sendMessage, + isHistoryLoading, + hasMoreHistory, + loadMoreHistory, + } = useThreadStream({ threadId: isNewThread ? undefined : threadId, context: { ...settings.context, agent_name: agent_name }, onStart: (createdThreadId) => { @@ -141,6 +147,9 @@ export default function AgentChatPage() { threadId={threadId} thread={thread} paddingBottom={messageListPaddingBottom} + hasMoreHistory={hasMoreHistory} + loadMoreHistory={loadMoreHistory} + isHistoryLoading={isHistoryLoading} /> diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index c5ff83dec..98e53e1c9 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { type PromptInputMessage } from "@/components/ai-elements/prompt-input"; import { ArtifactTrigger } from "@/components/workspace/artifacts"; @@ -35,19 +35,30 @@ export default function ChatPage() { const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } = useThreadChat(); const [settings, setSettings] = useThreadSettings(threadId); - const [mounted, setMounted] = useState(false); + const mountedRef = useRef(false); useSpecificChatMode(); useEffect(() => { - setMounted(true); + mountedRef.current = true; }, []); const { showNotification } = useNotification(); - const [thread, sendMessage, isUploading] = useThreadStream({ + const { + thread, + sendMessage, + isUploading, + isHistoryLoading, + hasMoreHistory, + loadMoreHistory, + } = useThreadStream({ threadId: isNewThread ? undefined : threadId, context: settings.context, isMock, + onSend: (_threadId) => { + setThreadId(_threadId); + setIsNewThread(false); + }, onStart: (createdThreadId) => { setThreadId(createdThreadId); setIsNewThread(false); @@ -115,6 +126,9 @@ export default function ChatPage() { threadId={threadId} thread={thread} paddingBottom={messageListPaddingBottom} + hasMoreHistory={hasMoreHistory} + loadMoreHistory={loadMoreHistory} + isHistoryLoading={isHistoryLoading} />
@@ -138,7 +152,7 @@ export default function ChatPage() { />
- {mounted ? ( + {mountedRef.current ? (