mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
Refactor API fetch calls to use a unified fetch function; enhance chat history loading with new hooks and UI components
- Replaced `fetchWithAuth` with a generic `fetch` function across various API modules for consistency. - Updated `useThreadStream` and `useThreadHistory` hooks to manage chat history loading, including loading states and pagination. - Introduced `LoadMoreHistoryIndicator` component for better user experience when loading more chat history. - Enhanced message handling in `MessageList` to accommodate new loading states and history management. - Added support for run messages in the thread context, improving the overall message handling logic. - Updated translations for loading indicators in English and Chinese.
This commit is contained in:
parent
7b9d224b3a
commit
df63c104a7
@ -8,13 +8,17 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.runtime import RunContext, RunManager
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
@ -22,6 +26,9 @@ if TYPE_CHECKING:
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
@ -84,25 +91,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _require(attr: str, label: str):
|
||||
def _require(attr: str, label: str) -> Callable[[Request], T]:
|
||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
||||
|
||||
def dep(request: Request):
|
||||
def dep(request: Request) -> T:
|
||||
val = getattr(request.app.state, attr, None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
return val
|
||||
return cast(T, val)
|
||||
|
||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
||||
return dep
|
||||
|
||||
|
||||
get_stream_bridge = _require("stream_bridge", "Stream bridge")
|
||||
get_run_manager = _require("run_manager", "Run manager")
|
||||
get_checkpointer = _require("checkpointer", "Checkpointer")
|
||||
get_run_event_store = _require("run_event_store", "Run event store")
|
||||
get_feedback_repo = _require("feedback_repo", "Feedback")
|
||||
get_run_store = _require("run_store", "Run store")
|
||||
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
|
||||
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
|
||||
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
|
||||
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
|
||||
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
|
||||
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
|
||||
|
||||
|
||||
def get_store(request: Request):
|
||||
@ -121,10 +128,7 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||
|
||||
Returns a *base* context with infrastructure dependencies. Callers that
|
||||
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
||||
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||
to :func:`run_agent`.
|
||||
Returns a *base* context with infrastructure dependencies.
|
||||
"""
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
@ -137,7 +141,6 @@ def get_run_context(request: Request) -> RunContext:
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py and auth middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -123,7 +123,8 @@ async def run_messages(
|
||||
run = await _resolve_run(run_id, request)
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
run["thread_id"], run_id,
|
||||
run["thread_id"],
|
||||
run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
|
||||
@ -54,7 +54,6 @@ class RunCreateRequest(BaseModel):
|
||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
@ -312,11 +311,15 @@ async def list_thread_messages(
|
||||
if i in last_ai_indices:
|
||||
run_id = msg["run_id"]
|
||||
fb = feedback_map.get(run_id)
|
||||
msg["feedback"] = {
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
} if fb else None
|
||||
msg["feedback"] = (
|
||||
{
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
}
|
||||
if fb
|
||||
else None
|
||||
)
|
||||
else:
|
||||
msg["feedback"] = None
|
||||
|
||||
@ -339,7 +342,8 @@ async def list_run_messages(
|
||||
"""
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
thread_id, run_id,
|
||||
thread_id,
|
||||
run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
|
||||
@ -56,7 +56,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||
|
||||
|
||||
@router.post("", response_model=UploadResponse)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
||||
async def upload_files(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
|
||||
@ -195,21 +195,6 @@ async def start_run(
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
||||
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
||||
if follow_up_to_run_id is None:
|
||||
run_store = get_run_store(request)
|
||||
try:
|
||||
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
||||
if recent_runs and recent_runs[0].get("status") == "success":
|
||||
follow_up_to_run_id = recent_runs[0]["run_id"]
|
||||
except Exception:
|
||||
pass # Don't block run creation
|
||||
|
||||
# Enrich base context with per-run field
|
||||
if follow_up_to_run_id:
|
||||
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
@ -218,7 +203,6 @@ async def start_run(
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
@ -9,6 +9,7 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||
|
||||
@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
# the conversation; injecting one mid-conversation crashes
|
||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||
# with all providers. See #1299.
|
||||
return {"messages": [HumanMessage(content=warning)]}
|
||||
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -0,0 +1,13 @@
|
||||
from typing import override
|
||||
|
||||
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
|
||||
|
||||
class SummarizationMiddleware(BaseSummarizationMiddleware):
|
||||
@override
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
"""Override the base implementation to let the human message with the special name 'summary'.
|
||||
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
||||
"""
|
||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||
@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
@ -97,8 +99,20 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
messages = list(state.get("messages", []))
|
||||
last_message = messages[-1] if messages else None
|
||||
|
||||
if last_message and isinstance(last_message, HumanMessage):
|
||||
messages[-1] = HumanMessage(
|
||||
content=last_message.content,
|
||||
id=last_message.id,
|
||||
name=last_message.name or "user-input",
|
||||
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
|
||||
)
|
||||
|
||||
return {
|
||||
"thread_data": {
|
||||
**paths,
|
||||
}
|
||||
},
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
@ -279,6 +279,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
updated_message = HumanMessage(
|
||||
content=f"{files_message}\n\n{original_content}",
|
||||
id=last_message.id,
|
||||
name=last_message.name,
|
||||
additional_kwargs=last_message.additional_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,10 @@ handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
||||
extracts the first human message for run.input, because it is more reliable than
|
||||
on_chain_start (fires on every node) — messages here are fully structured.
|
||||
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
@ -18,10 +21,12 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
@ -72,34 +77,39 @@ class RunJournal(BaseCallbackHandler):
|
||||
# LLM request/response tracking
|
||||
self._llm_call_index = 0
|
||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
|
||||
|
||||
# Tool call ID cache
|
||||
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_start",
|
||||
category="lifecycle",
|
||||
metadata={"input_preview": str(inputs)[:500]},
|
||||
)
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
caller = self._identify_caller(tags)
|
||||
if parent_run_id is None:
|
||||
# Root graph invocation — emit a single trace event for the run start.
|
||||
chain_name = (serialized or {}).get("name", "unknown")
|
||||
self._put(
|
||||
event_type="run.start",
|
||||
category="trace",
|
||||
content={"chain": chain_name},
|
||||
metadata={"caller": caller, **(metadata or {})},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_error",
|
||||
category="lifecycle",
|
||||
event_type="run.error",
|
||||
category="error",
|
||||
content=str(error),
|
||||
metadata={"error_type": type(error).__name__},
|
||||
)
|
||||
@ -107,266 +117,132 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- LLM callbacks --
|
||||
|
||||
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
"""Capture structured prompt messages for llm_request event."""
|
||||
from deerflow.runtime.converters import langchain_messages_to_openai
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict,
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Capture structured prompt messages for llm_request event.
|
||||
|
||||
This is also the canonical place to extract the first human message:
|
||||
messages are fully structured here, it fires only on real LLM calls,
|
||||
and the content is never compressed by checkpoint trimming.
|
||||
"""
|
||||
rid = str(run_id)
|
||||
self._llm_start_times[rid] = time.monotonic()
|
||||
self._llm_call_index += 1
|
||||
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
||||
self._cached_prompts[rid] = []
|
||||
|
||||
model_name = serialized.get("name", "")
|
||||
self._cached_models[rid] = model_name
|
||||
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
||||
|
||||
# Convert the first message list (LangChain passes list-of-lists)
|
||||
prompt_msgs = messages[0] if messages else []
|
||||
openai_msgs = langchain_messages_to_openai(prompt_msgs)
|
||||
self._cached_prompts[rid] = openai_msgs
|
||||
# Capture the first human message sent to any LLM in this run.
|
||||
if not self._first_human_msg:
|
||||
for batch in messages.reversed():
|
||||
for m in batch.reversed():
|
||||
if isinstance(m, HumanMessage) and m.name != "summary":
|
||||
caller = self._identify_caller(tags)
|
||||
self.set_first_human_message(m.text)
|
||||
self._put(
|
||||
event_type="llm.human.input",
|
||||
category="message",
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
|
||||
caller = self._identify_caller(kwargs)
|
||||
self._put(
|
||||
event_type="llm_request",
|
||||
category="trace",
|
||||
content={"model": model_name, "messages": openai_msgs},
|
||||
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
|
||||
)
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||
|
||||
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.converters import langchain_to_openai_completion
|
||||
|
||||
try:
|
||||
message = response.generations[0][0].message
|
||||
except (IndexError, AttributeError):
|
||||
logger.debug("on_llm_end: could not extract message from response")
|
||||
return
|
||||
|
||||
caller = self._identify_caller(kwargs)
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Clean up caches
|
||||
self._cached_prompts.pop(rid, None)
|
||||
self._cached_models.pop(rid, None)
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
content = getattr(message, "content", "")
|
||||
self._put(
|
||||
event_type="llm_response",
|
||||
category="trace",
|
||||
content=langchain_to_openai_completion(message),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Message events: only lead_agent gets message-category events.
|
||||
# Content uses message.model_dump() to align with checkpoint format.
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
if caller == "lead_agent":
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||
if tool_calls:
|
||||
# ai_tool_call: agent decided to use tools
|
||||
self._put(
|
||||
event_type="ai_tool_call",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
|
||||
)
|
||||
elif isinstance(content, str) and content:
|
||||
# ai_message: final text reply
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "stop"},
|
||||
)
|
||||
self._last_ai_msg = content
|
||||
self._msg_count += 1
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
||||
messages: list[AnyMessage] = []
|
||||
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
||||
for generation in response.generations:
|
||||
for gen in generation:
|
||||
if hasattr(gen, "message"):
|
||||
messages.append(gen.message)
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
||||
|
||||
for message in messages:
|
||||
caller = self._identify_caller(tags)
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
self._put(
|
||||
event_type="llm.ai.response",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm_error", category="trace", content=str(error))
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
|
||||
# -- Tool callbacks --
|
||||
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
||||
"""Handle tool start event, cache tool call ID for later correlation"""
|
||||
tool_call_id = str(run_id)
|
||||
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
||||
|
||||
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
tool_call_id = kwargs.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
self._tool_call_ids[str(run_id)] = tool_call_id
|
||||
self._put(
|
||||
event_type="tool_start",
|
||||
category="trace",
|
||||
metadata={
|
||||
"tool_name": serialized.get("name", ""),
|
||||
"tool_call_id": tool_call_id,
|
||||
"args": str(input_str)[:2000],
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
# Tools that update graph state return a ``Command`` (e.g.
|
||||
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
||||
# into checkpoint state, so to stay checkpoint-aligned we must
|
||||
# extract it here rather than storing ``str(Command(...))``.
|
||||
if isinstance(output, Command):
|
||||
update = getattr(output, "update", None) or {}
|
||||
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
||||
if isinstance(inner_msgs, list):
|
||||
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
||||
if inner_tool_msg is not None:
|
||||
output = inner_tool_msg
|
||||
|
||||
# Extract fields from ToolMessage object when LangChain provides one.
|
||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||
# with tool_call_id, name, status, and artifact — more complete than
|
||||
# what kwargs alone provides.
|
||||
if isinstance(output, ToolMessage):
|
||||
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = output.name or kwargs.get("name", "")
|
||||
status = getattr(output, "status", "success") or "success"
|
||||
content_str = output.content if isinstance(output.content, str) else str(output.content)
|
||||
# Use model_dump() for checkpoint-aligned message content.
|
||||
# Override tool_call_id if it was resolved from cache.
|
||||
msg_content = output.model_dump()
|
||||
if msg_content.get("tool_call_id") != tool_call_id:
|
||||
msg_content["tool_call_id"] = tool_call_id
|
||||
else:
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
status = "success"
|
||||
content_str = str(output)
|
||||
# Construct checkpoint-aligned dict when output is a plain string.
|
||||
msg_content = ToolMessage(
|
||||
content=content_str,
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status=status,
|
||||
).model_dump()
|
||||
|
||||
# Trace event (always)
|
||||
self._put(
|
||||
event_type="tool_end",
|
||||
category="trace",
|
||||
content=content_str,
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result (checkpoint-aligned model_dump format)
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": status},
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
|
||||
# Trace event
|
||||
self._put(
|
||||
event_type="tool_error",
|
||||
category="trace",
|
||||
content=str(error),
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result with error status (checkpoint-aligned)
|
||||
msg_content = ToolMessage(
|
||||
content=str(error),
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status="error",
|
||||
).model_dump()
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": "error"},
|
||||
)
|
||||
|
||||
# -- Custom event callback --
|
||||
|
||||
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
if name == "summarization":
|
||||
data_dict = data if isinstance(data, dict) else {}
|
||||
self._put(
|
||||
event_type="summarization",
|
||||
category="trace",
|
||||
content=data_dict.get("summary", ""),
|
||||
metadata={
|
||||
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
|
||||
"replaced_count": data_dict.get("replaced_count", 0),
|
||||
},
|
||||
)
|
||||
self._put(
|
||||
event_type="middleware:summarize",
|
||||
category="middleware",
|
||||
content={"role": "system", "content": data_dict.get("summary", "")},
|
||||
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
||||
)
|
||||
else:
|
||||
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
|
||||
self._put(
|
||||
event_type=name,
|
||||
category="trace",
|
||||
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
|
||||
)
|
||||
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
||||
"""Handle tool end event, append message and clear node data"""
|
||||
try:
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
||||
finally:
|
||||
logger.info(f"Tool end for node {run_id}")
|
||||
|
||||
# -- Internal methods --
|
||||
|
||||
@ -431,8 +307,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
if exc:
|
||||
logger.warning("Journal flush task failed: %s", exc)
|
||||
|
||||
def _identify_caller(self, kwargs: dict) -> str:
|
||||
for tag in kwargs.get("tags") or []:
|
||||
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
||||
_tags = tags or kwargs.get("tags", [])
|
||||
for tag in _tags:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
# Default to lead_agent: the main agent graph does not inject
|
||||
|
||||
@ -54,7 +54,7 @@ class RunManager:
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
|
||||
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
||||
async def _persist_to_store(self, record: RunRecord) -> None:
|
||||
"""Best-effort persist run record to backing store."""
|
||||
if self._store is None:
|
||||
return
|
||||
@ -68,7 +68,6 @@ class RunManager:
|
||||
metadata=record.metadata or {},
|
||||
kwargs=record.kwargs or {},
|
||||
created_at=record.created_at,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||
@ -90,7 +89,6 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
@ -109,7 +107,7 @@ class RunManager:
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
await self._persist_to_store(record)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
@ -122,7 +120,7 @@ class RunManager:
|
||||
async with self._lock:
|
||||
# Dict insertion order matches creation order, so reversing it gives
|
||||
# us deterministic newest-first results even when timestamps tie.
|
||||
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
||||
return [r for r in self._runs.values() if r.thread_id == thread_id]
|
||||
|
||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||
"""Transition a run to a new status."""
|
||||
@ -176,7 +174,6 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@ -230,7 +227,7 @@ class RunManager:
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
await self._persist_to_store(record)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
|
||||
@ -29,7 +29,6 @@ class RunStore(abc.ABC):
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
created_at: str | None = None,
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -28,7 +28,6 @@ class MemoryRunStore(RunStore):
|
||||
kwargs=None,
|
||||
error=None,
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
now = datetime.now(UTC).isoformat()
|
||||
self._runs[run_id] = {
|
||||
@ -41,7 +40,6 @@ class MemoryRunStore(RunStore):
|
||||
"metadata": metadata or {},
|
||||
"kwargs": kwargs or {},
|
||||
"error": error,
|
||||
"follow_up_to_run_id": follow_up_to_run_id,
|
||||
"created_at": created_at or now,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
@ -51,7 +51,6 @@ class RunContext:
|
||||
event_store: Any | None = field(default=None)
|
||||
run_events_config: Any | None = field(default=None)
|
||||
thread_store: Any | None = field(default=None)
|
||||
follow_up_to_run_id: str | None = field(default=None)
|
||||
|
||||
|
||||
async def run_agent(
|
||||
@ -76,7 +75,6 @@ async def run_agent(
|
||||
event_store = ctx.event_store
|
||||
run_events_config = ctx.run_events_config
|
||||
thread_store = ctx.thread_store
|
||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||
|
||||
run_id = record.run_id
|
||||
thread_id = record.thread_id
|
||||
@ -113,22 +111,6 @@ async def run_agent(
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
@ -166,12 +148,13 @@ async def run_agent(
|
||||
|
||||
# Inject runtime context so middlewares can access thread_id
|
||||
# (langgraph-cli does this automatically; we must do it manually)
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||
# prefers it over ``configurable`` for thread-level data), make
|
||||
# sure ``thread_id`` is available there too.
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config["context"].setdefault("run_id", run_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
|
||||
@ -46,7 +46,13 @@ export default function AgentChatPage() {
|
||||
const [settings, setSettings] = useThreadSettings(threadId);
|
||||
|
||||
const { showNotification } = useNotification();
|
||||
const [thread, sendMessage] = useThreadStream({
|
||||
const {
|
||||
thread,
|
||||
sendMessage,
|
||||
isHistoryLoading,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
} = useThreadStream({
|
||||
threadId: isNewThread ? undefined : threadId,
|
||||
context: { ...settings.context, agent_name: agent_name },
|
||||
onStart: (createdThreadId) => {
|
||||
@ -141,6 +147,9 @@ export default function AgentChatPage() {
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
paddingBottom={messageListPaddingBottom}
|
||||
hasMoreHistory={hasMoreHistory}
|
||||
loadMoreHistory={loadMoreHistory}
|
||||
isHistoryLoading={isHistoryLoading}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
||||
import { ArtifactTrigger } from "@/components/workspace/artifacts";
|
||||
@ -35,19 +35,30 @@ export default function ChatPage() {
|
||||
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
|
||||
useThreadChat();
|
||||
const [settings, setSettings] = useThreadSettings(threadId);
|
||||
const [mounted, setMounted] = useState(false);
|
||||
const mountedRef = useRef(false);
|
||||
useSpecificChatMode();
|
||||
|
||||
useEffect(() => {
|
||||
setMounted(true);
|
||||
mountedRef.current = true;
|
||||
}, []);
|
||||
|
||||
const { showNotification } = useNotification();
|
||||
|
||||
const [thread, sendMessage, isUploading] = useThreadStream({
|
||||
const {
|
||||
thread,
|
||||
sendMessage,
|
||||
isUploading,
|
||||
isHistoryLoading,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
} = useThreadStream({
|
||||
threadId: isNewThread ? undefined : threadId,
|
||||
context: settings.context,
|
||||
isMock,
|
||||
onSend: (_threadId) => {
|
||||
setThreadId(_threadId);
|
||||
setIsNewThread(false);
|
||||
},
|
||||
onStart: (createdThreadId) => {
|
||||
setThreadId(createdThreadId);
|
||||
setIsNewThread(false);
|
||||
@ -115,6 +126,9 @@ export default function ChatPage() {
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
paddingBottom={messageListPaddingBottom}
|
||||
hasMoreHistory={hasMoreHistory}
|
||||
loadMoreHistory={loadMoreHistory}
|
||||
isHistoryLoading={isHistoryLoading}
|
||||
/>
|
||||
</div>
|
||||
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
|
||||
@ -138,7 +152,7 @@ export default function ChatPage() {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{mounted ? (
|
||||
{mountedRef.current ? (
|
||||
<InputBox
|
||||
className={cn("bg-background/5 w-full -translate-y-4")}
|
||||
isNewThread={isNewThread}
|
||||
@ -170,7 +184,7 @@ export default function ChatPage() {
|
||||
<div
|
||||
aria-hidden="true"
|
||||
className={cn(
|
||||
"bg-background/5 h-32 w-full -translate-y-4 rounded-2xl border",
|
||||
"bg-background/5 h-32 w-full -translate-y-4 rounded-2xl",
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
@ -55,7 +55,7 @@ import {
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
||||
import { fetch } from "@/core/api/fetcher";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { useModels } from "@/core/models/hooks";
|
||||
@ -155,6 +155,7 @@ export function InputBox({
|
||||
const [followupsLoading, setFollowupsLoading] = useState(false);
|
||||
const lastGeneratedForAiIdRef = useRef<string | null>(null);
|
||||
const wasStreamingRef = useRef(false);
|
||||
const messagesRef = useRef(thread.messages);
|
||||
|
||||
const [confirmOpen, setConfirmOpen] = useState(false);
|
||||
const [pendingSuggestion, setPendingSuggestion] = useState<string | null>(
|
||||
@ -354,6 +355,10 @@ export function InputBox({
|
||||
followupsVisibilityChangeRef.current?.(showFollowups);
|
||||
}, [showFollowups]);
|
||||
|
||||
useEffect(() => {
|
||||
messagesRef.current = thread.messages;
|
||||
}, [thread.messages]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => followupsVisibilityChangeRef.current?.(false);
|
||||
}, []);
|
||||
@ -370,14 +375,16 @@ export function InputBox({
|
||||
return;
|
||||
}
|
||||
|
||||
const lastAi = [...thread.messages].reverse().find((m) => m.type === "ai");
|
||||
const lastAi = [...messagesRef.current]
|
||||
.reverse()
|
||||
.find((m) => m.type === "ai");
|
||||
const lastAiId = lastAi?.id ?? null;
|
||||
if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) {
|
||||
return;
|
||||
}
|
||||
lastGeneratedForAiIdRef.current = lastAiId;
|
||||
|
||||
const recent = thread.messages
|
||||
const recent = messagesRef.current
|
||||
.filter((m) => m.type === "human" || m.type === "ai")
|
||||
.map((m) => {
|
||||
const role = m.type === "human" ? "user" : "assistant";
|
||||
@ -396,19 +403,16 @@ export function InputBox({
|
||||
setFollowupsLoading(true);
|
||||
setFollowups([]);
|
||||
|
||||
fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
messages: recent,
|
||||
n: 3,
|
||||
model_name: context.model_name ?? undefined,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
},
|
||||
)
|
||||
fetch(`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
messages: recent,
|
||||
n: 3,
|
||||
model_name: context.model_name ?? undefined,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) {
|
||||
return { suggestions: [] as string[] };
|
||||
@ -430,7 +434,7 @@ export function InputBox({
|
||||
});
|
||||
|
||||
return () => controller.abort();
|
||||
}, [context.model_name, disabled, isMock, status, thread.messages, threadId]);
|
||||
}, [context.model_name, disabled, isMock, status, threadId]);
|
||||
|
||||
return (
|
||||
<div ref={promptRootRef} className="relative flex flex-col gap-4">
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import type { BaseStream } from "@langchain/langgraph-sdk/react";
|
||||
import { ChevronUpIcon, Loader2Icon } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
|
||||
import {
|
||||
Conversation,
|
||||
ConversationContent,
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import {
|
||||
extractContentFromMessage,
|
||||
@ -18,7 +21,6 @@ import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
|
||||
import type { Subtask } from "@/core/tasks";
|
||||
import { useUpdateSubtask } from "@/core/tasks/context";
|
||||
import type { AgentThreadState } from "@/core/threads";
|
||||
import { useThreadMessageEnrichment } from "@/core/threads/hooks";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { ArtifactFileList } from "../artifacts/artifact-file-list";
|
||||
@ -33,22 +35,134 @@ import { SubtaskCard } from "./subtask-card";
|
||||
export const MESSAGE_LIST_DEFAULT_PADDING_BOTTOM = 160;
|
||||
export const MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM = 80;
|
||||
|
||||
const LOAD_MORE_HISTORY_THROTTLE_MS = 1200;
|
||||
|
||||
function LoadMoreHistoryIndicator({
|
||||
isLoading,
|
||||
hasMore,
|
||||
loadMore,
|
||||
}: {
|
||||
isLoading?: boolean;
|
||||
hasMore?: boolean;
|
||||
loadMore?: () => void;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||
const timeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const lastLoadRef = useRef(0);
|
||||
|
||||
const throttledLoadMore = useCallback(() => {
|
||||
if (!hasMore || isLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
const now = Date.now();
|
||||
const remaining =
|
||||
LOAD_MORE_HISTORY_THROTTLE_MS - (now - lastLoadRef.current);
|
||||
|
||||
if (remaining <= 0) {
|
||||
lastLoadRef.current = now;
|
||||
loadMore?.();
|
||||
return;
|
||||
}
|
||||
|
||||
if (timeoutRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
timeoutRef.current = setTimeout(() => {
|
||||
timeoutRef.current = null;
|
||||
if (!hasMore || isLoading) {
|
||||
return;
|
||||
}
|
||||
lastLoadRef.current = Date.now();
|
||||
loadMore?.();
|
||||
}, remaining);
|
||||
}, [hasMore, isLoading, loadMore]);
|
||||
|
||||
useEffect(() => {
|
||||
const element = sentinelRef.current;
|
||||
if (!element || !hasMore) {
|
||||
return;
|
||||
}
|
||||
|
||||
const observer = new IntersectionObserver(
|
||||
([entry]) => {
|
||||
if (entry?.isIntersecting) {
|
||||
throttledLoadMore();
|
||||
}
|
||||
},
|
||||
{
|
||||
rootMargin: "120px 0px 0px 0px",
|
||||
},
|
||||
);
|
||||
|
||||
observer.observe(element);
|
||||
|
||||
return () => {
|
||||
observer.disconnect();
|
||||
};
|
||||
}, [hasMore, throttledLoadMore]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (timeoutRef.current) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (!hasMore && !isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div ref={sentinelRef} className="flex w-full justify-center">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="text-muted-foreground hover:text-foreground rounded-full px-3"
|
||||
disabled={(isLoading ?? false) || !hasMore}
|
||||
onClick={throttledLoadMore}
|
||||
>
|
||||
{isLoading ? (
|
||||
<>
|
||||
<Loader2Icon className="mr-2 size-4 animate-spin" />
|
||||
{t.common.loading}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ChevronUpIcon className="mr-2 size-4" />
|
||||
{t.common.loadMore}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function MessageList({
|
||||
className,
|
||||
threadId,
|
||||
thread,
|
||||
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
isHistoryLoading,
|
||||
}: {
|
||||
className?: string;
|
||||
threadId: string;
|
||||
thread: BaseStream<AgentThreadState>;
|
||||
paddingBottom?: number;
|
||||
hasMoreHistory?: boolean;
|
||||
loadMoreHistory?: () => void;
|
||||
isHistoryLoading?: boolean;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
|
||||
const updateSubtask = useUpdateSubtask();
|
||||
const messages = thread.messages;
|
||||
const { data: enrichment } = useThreadMessageEnrichment(threadId);
|
||||
|
||||
if (thread.isThreadLoading && messages.length === 0) {
|
||||
return <MessageListSkeleton />;
|
||||
@ -57,19 +171,21 @@ export function MessageList({
|
||||
<Conversation
|
||||
className={cn("flex size-full flex-col justify-center", className)}
|
||||
>
|
||||
<ConversationContent className="mx-auto w-full max-w-(--container-width-md) gap-8 pt-12">
|
||||
<ConversationContent className="mx-auto w-full max-w-(--container-width-md) gap-8 pt-8">
|
||||
<LoadMoreHistoryIndicator
|
||||
isLoading={isHistoryLoading}
|
||||
hasMore={hasMoreHistory}
|
||||
loadMore={loadMoreHistory}
|
||||
/>
|
||||
{groupMessages(messages, (group) => {
|
||||
if (group.type === "human" || group.type === "assistant") {
|
||||
return group.messages.map((msg) => {
|
||||
const entry = msg.id ? enrichment?.get(msg.id) : undefined;
|
||||
return (
|
||||
<MessageListItem
|
||||
key={`${group.id}/${msg.id}`}
|
||||
threadId={threadId}
|
||||
message={msg}
|
||||
isLoading={thread.isLoading}
|
||||
runId={entry?.run_id}
|
||||
feedback={entry?.feedback}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@ -5,7 +5,7 @@ import { useState } from "react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { fetchWithAuth, getCsrfHeaders } from "@/core/api/fetcher";
|
||||
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
|
||||
import { useAuth } from "@/core/auth/AuthProvider";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
|
||||
@ -36,7 +36,7 @@ export function AccountSettingsPage() {
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetchWithAuth("/api/v1/auth/change-password", {
|
||||
const res = await fetch("/api/v1/auth/change-password", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
||||
import { fetch } from "@/core/api/fetcher";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import type { Agent, CreateAgentRequest, UpdateAgentRequest } from "./types";
|
||||
@ -29,7 +29,7 @@ export async function getAgent(name: string): Promise<Agent> {
|
||||
}
|
||||
|
||||
export async function createAgent(request: CreateAgentRequest): Promise<Agent> {
|
||||
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents`, {
|
||||
const res = await fetch(`${getBackendBaseURL()}/api/agents`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
@ -45,7 +45,7 @@ export async function updateAgent(
|
||||
name: string,
|
||||
request: UpdateAgentRequest,
|
||||
): Promise<Agent> {
|
||||
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents/${name}`, {
|
||||
const res = await fetch(`${getBackendBaseURL()}/api/agents/${name}`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
@ -58,7 +58,7 @@ export async function updateAgent(
|
||||
}
|
||||
|
||||
export async function deleteAgent(name: string): Promise<void> {
|
||||
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents/${name}`, {
|
||||
const res = await fetch(`${getBackendBaseURL()}/api/agents/${name}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
if (!res.ok) throw new Error(`Failed to delete agent: ${res.statusText}`);
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { getBackendBaseURL } from "../config";
|
||||
|
||||
import { fetchWithAuth } from "./fetcher";
|
||||
import { fetch } from "./fetcher";
|
||||
|
||||
export interface FeedbackData {
|
||||
feedback_id: string;
|
||||
@ -14,7 +14,7 @@ export async function upsertFeedback(
|
||||
rating: number,
|
||||
comment?: string,
|
||||
): Promise<FeedbackData> {
|
||||
const res = await fetchWithAuth(
|
||||
const res = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
|
||||
{
|
||||
method: "PUT",
|
||||
@ -32,7 +32,7 @@ export async function deleteFeedback(
|
||||
threadId: string,
|
||||
runId: string,
|
||||
): Promise<void> {
|
||||
const res = await fetchWithAuth(
|
||||
const res = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
|
||||
{ method: "DELETE" },
|
||||
);
|
||||
|
||||
@ -53,7 +53,7 @@ export function readCsrfCookie(): string | null {
|
||||
* preserved; the helper only ADDS the CSRF header when it isn't already
|
||||
* present, so explicit overrides win.
|
||||
*/
|
||||
export async function fetchWithAuth(
|
||||
export async function fetch(
|
||||
input: RequestInfo | string,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
@ -74,7 +74,7 @@ export async function fetchWithAuth(
|
||||
}
|
||||
}
|
||||
|
||||
const res = await fetch(url, {
|
||||
const res = await globalThis.fetch(url, {
|
||||
...init,
|
||||
headers,
|
||||
credentials: "include",
|
||||
|
||||
@ -29,6 +29,7 @@ export const enUS: Translations = {
|
||||
close: "Close",
|
||||
more: "More",
|
||||
search: "Search",
|
||||
loadMore: "Load more",
|
||||
download: "Download",
|
||||
thinking: "Thinking",
|
||||
artifacts: "Artifacts",
|
||||
|
||||
@ -18,6 +18,7 @@ export interface Translations {
|
||||
close: string;
|
||||
more: string;
|
||||
search: string;
|
||||
loadMore: string;
|
||||
download: string;
|
||||
thinking: string;
|
||||
artifacts: string;
|
||||
|
||||
@ -29,6 +29,7 @@ export const zhCN: Translations = {
|
||||
close: "关闭",
|
||||
more: "更多",
|
||||
search: "搜索",
|
||||
loadMore: "加载更多",
|
||||
download: "下载",
|
||||
thinking: "思考",
|
||||
artifacts: "文件",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
||||
import { fetch } from "@/core/api/fetcher";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import type { MCPConfig } from "./types";
|
||||
@ -9,15 +9,12 @@ export async function loadMCPConfig() {
|
||||
}
|
||||
|
||||
export async function updateMCPConfig(config: MCPConfig) {
|
||||
const response = await fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/mcp/config`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(config),
|
||||
const response = await fetch(`${getBackendBaseURL()}/api/mcp/config`, {
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
);
|
||||
body: JSON.stringify(config),
|
||||
});
|
||||
return response.json();
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { fetchWithAuth } from "../api/fetcher";
|
||||
import { fetch } from "../api/fetcher";
|
||||
import { getBackendBaseURL } from "../config";
|
||||
|
||||
import type {
|
||||
@ -86,14 +86,14 @@ export async function loadMemory(): Promise<UserMemory> {
|
||||
}
|
||||
|
||||
export async function clearMemory(): Promise<UserMemory> {
|
||||
const response = await fetchWithAuth(`${getBackendBaseURL()}/api/memory`, {
|
||||
const response = await fetch(`${getBackendBaseURL()}/api/memory`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return readMemoryResponse(response, "Failed to clear memory");
|
||||
}
|
||||
|
||||
export async function deleteMemoryFact(factId: string): Promise<UserMemory> {
|
||||
const response = await fetchWithAuth(
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
@ -108,32 +108,26 @@ export async function exportMemory(): Promise<UserMemory> {
|
||||
}
|
||||
|
||||
export async function importMemory(memory: UserMemory): Promise<UserMemory> {
|
||||
const response = await fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/memory/import`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(memory),
|
||||
const response = await fetch(`${getBackendBaseURL()}/api/memory/import`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
);
|
||||
body: JSON.stringify(memory),
|
||||
});
|
||||
return readMemoryResponse(response, "Failed to import memory");
|
||||
}
|
||||
|
||||
export async function createMemoryFact(
|
||||
input: MemoryFactInput,
|
||||
): Promise<UserMemory> {
|
||||
const response = await fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/memory/facts`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(input),
|
||||
const response = await fetch(`${getBackendBaseURL()}/api/memory/facts`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
);
|
||||
body: JSON.stringify(input),
|
||||
});
|
||||
return readMemoryResponse(response, "Failed to create memory fact");
|
||||
}
|
||||
|
||||
@ -141,7 +135,7 @@ export async function updateMemoryFact(
|
||||
factId: string,
|
||||
input: MemoryFactPatchInput,
|
||||
): Promise<UserMemory> {
|
||||
const response = await fetchWithAuth(
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
|
||||
{
|
||||
method: "PATCH",
|
||||
|
||||
@ -328,7 +328,11 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
|
||||
}
|
||||
|
||||
export function isHiddenFromUIMessage(message: Message) {
|
||||
return message.additional_kwargs?.hide_from_ui === true;
|
||||
return (
|
||||
message.additional_kwargs?.hide_from_ui === true ||
|
||||
message.name === "summary" ||
|
||||
message.name === "loop_warning"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
||||
import { fetch } from "@/core/api/fetcher";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import type { Skill } from "./type";
|
||||
@ -10,7 +10,7 @@ export async function loadSkills() {
|
||||
}
|
||||
|
||||
export async function enableSkill(skillName: string, enabled: boolean) {
|
||||
const response = await fetchWithAuth(
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/skills/${skillName}`,
|
||||
{
|
||||
method: "PUT",
|
||||
@ -39,16 +39,13 @@ export interface InstallSkillResponse {
|
||||
export async function installSkill(
|
||||
request: InstallSkillRequest,
|
||||
): Promise<InstallSkillResponse> {
|
||||
const response = await fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/skills/install`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
const response = await fetch(`${getBackendBaseURL()}/api/skills/install`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
);
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
// Handle HTTP error responses (4xx, 5xx)
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
import type { AIMessage, Message } from "@langchain/langgraph-sdk";
|
||||
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
|
||||
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
||||
|
||||
import { getAPIClient } from "../api";
|
||||
import type { FeedbackData } from "../api/feedback";
|
||||
import { fetchWithAuth } from "../api/fetcher";
|
||||
import { fetch } from "../api/fetcher";
|
||||
import { getBackendBaseURL } from "../config";
|
||||
import { useI18n } from "../i18n/hooks";
|
||||
import type { FileInMessage } from "../messages/utils";
|
||||
@ -18,7 +17,7 @@ import { useUpdateSubtask } from "../tasks/context";
|
||||
import type { UploadedFileInfo } from "../uploads";
|
||||
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
||||
|
||||
import type { AgentThread, AgentThreadState } from "./types";
|
||||
import type { AgentThread, AgentThreadState, RunMessage } from "./types";
|
||||
|
||||
export type ToolEndEvent = {
|
||||
name: string;
|
||||
@ -29,7 +28,8 @@ export type ThreadStreamOptions = {
|
||||
threadId?: string | null | undefined;
|
||||
context: LocalSettings["context"];
|
||||
isMock?: boolean;
|
||||
onStart?: (threadId: string) => void;
|
||||
onSend?: (threadId: string) => void;
|
||||
onStart?: (threadId: string, runId: string) => void;
|
||||
onFinish?: (state: AgentThreadState) => void;
|
||||
onToolEnd?: (event: ToolEndEvent) => void;
|
||||
};
|
||||
@ -38,79 +38,41 @@ type SendMessageOptions = {
|
||||
additionalKwargs?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function normalizeStoredRunId(runId: string | null): string | null {
|
||||
if (!runId) {
|
||||
return null;
|
||||
}
|
||||
function mergeMessages(
|
||||
historyMessages: Message[],
|
||||
threadMessages: Message[],
|
||||
optimisticMessages: Message[],
|
||||
): Message[] {
|
||||
const threadMessageIds = new Set(
|
||||
threadMessages
|
||||
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
|
||||
.filter(Boolean),
|
||||
);
|
||||
|
||||
const trimmed = runId.trim();
|
||||
if (!trimmed) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const queryIndex = trimmed.indexOf("?");
|
||||
if (queryIndex >= 0) {
|
||||
const params = new URLSearchParams(trimmed.slice(queryIndex + 1));
|
||||
const queryRunId = params.get("run_id")?.trim();
|
||||
if (queryRunId) {
|
||||
return queryRunId;
|
||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||
// Scan from the end: shrink cutoff while messages are already in thread, stop as soon as
|
||||
// we hit one that isn't — everything before that point is non-overlapping.
|
||||
let cutoff = historyMessages.length;
|
||||
for (let i = historyMessages.length - 1; i >= 0; i--) {
|
||||
const msg = historyMessages[i];
|
||||
if (!msg) {
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
(msg?.id && threadMessageIds.has(msg.id)) ||
|
||||
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
|
||||
) {
|
||||
cutoff = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const pathWithoutQueryOrHash = trimmed.split(/[?#]/, 1)[0]?.trim() ?? "";
|
||||
if (!pathWithoutQueryOrHash) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const runsMarker = "/runs/";
|
||||
const runsIndex = pathWithoutQueryOrHash.lastIndexOf(runsMarker);
|
||||
if (runsIndex >= 0) {
|
||||
const runIdAfterMarker = pathWithoutQueryOrHash
|
||||
.slice(runsIndex + runsMarker.length)
|
||||
.split("/", 1)[0]
|
||||
?.trim();
|
||||
if (runIdAfterMarker) {
|
||||
return runIdAfterMarker;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
const segments = pathWithoutQueryOrHash
|
||||
.split("/")
|
||||
.map((segment) => segment.trim())
|
||||
.filter(Boolean);
|
||||
return segments.at(-1) ?? null;
|
||||
}
|
||||
|
||||
function getRunMetadataStorage(): {
|
||||
getItem(key: `lg:stream:${string}`): string | null;
|
||||
setItem(key: `lg:stream:${string}`, value: string): void;
|
||||
removeItem(key: `lg:stream:${string}`): void;
|
||||
} {
|
||||
return {
|
||||
getItem(key) {
|
||||
const normalized = normalizeStoredRunId(
|
||||
window.sessionStorage.getItem(key),
|
||||
);
|
||||
if (normalized) {
|
||||
window.sessionStorage.setItem(key, normalized);
|
||||
return normalized;
|
||||
}
|
||||
window.sessionStorage.removeItem(key);
|
||||
return null;
|
||||
},
|
||||
setItem(key, value) {
|
||||
const normalized = normalizeStoredRunId(value);
|
||||
if (normalized) {
|
||||
window.sessionStorage.setItem(key, normalized);
|
||||
return;
|
||||
}
|
||||
window.sessionStorage.removeItem(key);
|
||||
},
|
||||
removeItem(key) {
|
||||
window.sessionStorage.removeItem(key);
|
||||
},
|
||||
};
|
||||
return [
|
||||
...historyMessages.slice(0, cutoff),
|
||||
...threadMessages,
|
||||
...optimisticMessages,
|
||||
];
|
||||
}
|
||||
|
||||
function getStreamErrorMessage(error: unknown): string {
|
||||
@ -140,6 +102,7 @@ export function useThreadStream({
|
||||
threadId,
|
||||
context,
|
||||
isMock,
|
||||
onSend,
|
||||
onStart,
|
||||
onFinish,
|
||||
onToolEnd,
|
||||
@ -151,17 +114,25 @@ export function useThreadStream({
|
||||
// and to allow access to the current thread id in onUpdateEvent
|
||||
const threadIdRef = useRef<string | null>(threadId ?? null);
|
||||
const startedRef = useRef(false);
|
||||
|
||||
const listeners = useRef({
|
||||
onSend,
|
||||
onStart,
|
||||
onFinish,
|
||||
onToolEnd,
|
||||
});
|
||||
|
||||
const {
|
||||
messages: history,
|
||||
hasMore: hasMoreHistory,
|
||||
loadMore: loadMoreHistory,
|
||||
loading: isHistoryLoading,
|
||||
appendMessages,
|
||||
} = useThreadHistory(onStreamThreadId ?? "");
|
||||
|
||||
// Keep listeners ref updated with latest callbacks
|
||||
useEffect(() => {
|
||||
listeners.current = { onStart, onFinish, onToolEnd };
|
||||
}, [onStart, onFinish, onToolEnd]);
|
||||
listeners.current = { onSend, onStart, onFinish, onToolEnd };
|
||||
}, [onSend, onStart, onFinish, onToolEnd]);
|
||||
|
||||
useEffect(() => {
|
||||
const normalizedThreadId = threadId ?? null;
|
||||
@ -175,45 +146,26 @@ export function useThreadStream({
|
||||
threadIdRef.current = normalizedThreadId;
|
||||
}, [threadId]);
|
||||
|
||||
const _handleOnStart = useCallback((id: string) => {
|
||||
const handleStreamStart = useCallback((_threadId: string, _runId: string) => {
|
||||
threadIdRef.current = _threadId;
|
||||
if (!startedRef.current) {
|
||||
listeners.current.onStart?.(id);
|
||||
listeners.current.onStart?.(_threadId, _runId);
|
||||
startedRef.current = true;
|
||||
}
|
||||
setOnStreamThreadId(_threadId);
|
||||
}, []);
|
||||
|
||||
const handleStreamStart = useCallback(
|
||||
(_threadId: string) => {
|
||||
threadIdRef.current = _threadId;
|
||||
_handleOnStart(_threadId);
|
||||
},
|
||||
[_handleOnStart],
|
||||
);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const updateSubtask = useUpdateSubtask();
|
||||
const runMetadataStorageRef = useRef<
|
||||
ReturnType<typeof getRunMetadataStorage> | undefined
|
||||
>(undefined);
|
||||
|
||||
if (
|
||||
typeof window !== "undefined" &&
|
||||
runMetadataStorageRef.current === undefined
|
||||
) {
|
||||
runMetadataStorageRef.current = getRunMetadataStorage();
|
||||
}
|
||||
|
||||
const thread = useStream<AgentThreadState>({
|
||||
client: getAPIClient(isMock),
|
||||
assistantId: "lead_agent",
|
||||
threadId: onStreamThreadId,
|
||||
reconnectOnMount: runMetadataStorageRef.current
|
||||
? () => runMetadataStorageRef.current!
|
||||
: false,
|
||||
reconnectOnMount: true,
|
||||
fetchStateHistory: { limit: 1 },
|
||||
onCreated(meta) {
|
||||
handleStreamStart(meta.thread_id);
|
||||
setOnStreamThreadId(meta.thread_id);
|
||||
handleStreamStart(meta.thread_id, meta.run_id);
|
||||
if (context.agent_name && !isMock) {
|
||||
void getAPIClient()
|
||||
.threads.update(meta.thread_id, {
|
||||
@ -231,6 +183,34 @@ export function useThreadStream({
|
||||
}
|
||||
},
|
||||
onUpdateEvent(data) {
|
||||
if (data["SummarizationMiddleware.before_model"]) {
|
||||
const _messages = [
|
||||
...(data["SummarizationMiddleware.before_model"].messages ?? []),
|
||||
];
|
||||
|
||||
if (_messages.length < 2) {
|
||||
return;
|
||||
}
|
||||
for (const m of _messages) {
|
||||
if (m.name === "summary" && m.type === "human") {
|
||||
summarizedRef.current?.add(m.id ?? "");
|
||||
}
|
||||
}
|
||||
const _lastKeepMessage = _messages[2];
|
||||
const _currentMessages = [...messagesRef.current];
|
||||
const _movedMessages: Message[] = [];
|
||||
for (const m of _currentMessages) {
|
||||
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
|
||||
break;
|
||||
}
|
||||
if (!summarizedRef.current?.has(m.id ?? "")) {
|
||||
_movedMessages.push(m);
|
||||
}
|
||||
}
|
||||
appendMessages(_movedMessages);
|
||||
messagesRef.current = [];
|
||||
}
|
||||
|
||||
const updates: Array<Partial<AgentThreadState> | null> = Object.values(
|
||||
data || {},
|
||||
);
|
||||
@ -295,9 +275,6 @@ export function useThreadStream({
|
||||
onFinish(state) {
|
||||
listeners.current.onFinish?.(state.values);
|
||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: ["thread-message-enrichment"],
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@ -305,24 +282,25 @@ export function useThreadStream({
|
||||
const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const sendInFlightRef = useRef(false);
|
||||
const messagesRef = useRef<Message[]>([]);
|
||||
const summarizedRef = useRef<Set<string>>(null);
|
||||
// Track message count before sending so we know when server has responded
|
||||
const prevMsgCountRef = useRef(thread.messages.length);
|
||||
|
||||
summarizedRef.current ??= new Set<string>();
|
||||
|
||||
// Reset thread-local pending UI state when switching between threads so
|
||||
// optimistic messages and in-flight guards do not leak across chat views.
|
||||
useEffect(() => {
|
||||
startedRef.current = false;
|
||||
sendInFlightRef.current = false;
|
||||
prevMsgCountRef.current = 0;
|
||||
setOptimisticMessages([]);
|
||||
setIsUploading(false);
|
||||
}, [threadId]);
|
||||
|
||||
// Clear optimistic when server messages arrive (count increases)
|
||||
useEffect(() => {
|
||||
if (
|
||||
optimisticMessages.length > 0 &&
|
||||
thread.messages.length > prevMsgCountRef.current + 1
|
||||
thread.messages.length > prevMsgCountRef.current
|
||||
) {
|
||||
setOptimisticMessages([]);
|
||||
}
|
||||
@ -381,12 +359,7 @@ export function useThreadStream({
|
||||
}
|
||||
setOptimisticMessages(newOptimistic);
|
||||
|
||||
// Only fire onStart immediately for an existing persisted thread.
|
||||
// Brand-new chats should wait for onCreated(meta.thread_id) so URL sync
|
||||
// uses the real server-generated thread id.
|
||||
if (threadIdRef.current) {
|
||||
_handleOnStart(threadId);
|
||||
}
|
||||
listeners.current.onSend?.(threadId);
|
||||
|
||||
let uploadedFileInfo: UploadedFileInfo[] = [];
|
||||
|
||||
@ -520,19 +493,106 @@ export function useThreadStream({
|
||||
sendInFlightRef.current = false;
|
||||
}
|
||||
},
|
||||
[thread, _handleOnStart, t.uploads.uploadingFiles, context, queryClient],
|
||||
[thread, t.uploads.uploadingFiles, context, queryClient],
|
||||
);
|
||||
|
||||
// Merge thread with optimistic messages for display
|
||||
const mergedThread =
|
||||
optimisticMessages.length > 0
|
||||
? ({
|
||||
...thread,
|
||||
messages: [...thread.messages, ...optimisticMessages],
|
||||
} as typeof thread)
|
||||
: thread;
|
||||
// Cache the latest thread messages in a ref to compare against incoming history messages for deduplication,
|
||||
// and to allow access to the full message list in onUpdateEvent without causing re-renders.
|
||||
if (thread.messages.length >= messagesRef.current.length) {
|
||||
messagesRef.current = thread.messages;
|
||||
}
|
||||
|
||||
return [mergedThread, sendMessage, isUploading] as const;
|
||||
const mergedMessages = mergeMessages(
|
||||
history,
|
||||
thread.messages,
|
||||
optimisticMessages,
|
||||
);
|
||||
|
||||
// Merge history, live stream, and optimistic messages for display
|
||||
// History messages may overlap with thread.messages; thread.messages take precedence
|
||||
const mergedThread = {
|
||||
...thread,
|
||||
messages: mergedMessages,
|
||||
} as typeof thread;
|
||||
|
||||
return {
|
||||
thread: mergedThread,
|
||||
sendMessage,
|
||||
isUploading,
|
||||
isHistoryLoading,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
} as const;
|
||||
}
|
||||
|
||||
export function useThreadHistory(threadId: string) {
|
||||
const runs = useThreadRuns(threadId);
|
||||
const threadIdRef = useRef(threadId);
|
||||
const runsRef = useRef(runs.data ?? []);
|
||||
const indexRef = useRef(-1);
|
||||
const loadingRef = useRef(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
|
||||
loadingRef.current = loading;
|
||||
const loadMessages = useCallback(async () => {
|
||||
if (runsRef.current.length === 0) {
|
||||
return;
|
||||
}
|
||||
const run = runsRef.current[indexRef.current];
|
||||
if (!run || loadingRef.current) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
setLoading(true);
|
||||
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
credentials: "include",
|
||||
},
|
||||
).then((res) => {
|
||||
return res.json();
|
||||
});
|
||||
const _messages = result.data
|
||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||
.map((m) => m.content);
|
||||
setMessages((prev) => [..._messages, ...prev]);
|
||||
indexRef.current -= 1;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
useEffect(() => {
|
||||
threadIdRef.current = threadId;
|
||||
if (runs.data && runs.data.length > 0) {
|
||||
runsRef.current = runs.data ?? [];
|
||||
indexRef.current = runs.data.length - 1;
|
||||
}
|
||||
loadMessages().catch(() => {
|
||||
toast.error("Failed to load thread history.");
|
||||
});
|
||||
}, [threadId, runs.data, loadMessages]);
|
||||
|
||||
const appendMessages = useCallback((_messages: Message[]) => {
|
||||
setMessages((prev) => {
|
||||
return [...prev, ..._messages];
|
||||
});
|
||||
}, []);
|
||||
const hasMore = indexRef.current >= 0 || !runs.data;
|
||||
return {
|
||||
runs: runs.data,
|
||||
messages,
|
||||
loading,
|
||||
appendMessages,
|
||||
hasMore,
|
||||
loadMore: loadMessages,
|
||||
};
|
||||
}
|
||||
|
||||
export function useThreads(
|
||||
@ -602,6 +662,33 @@ export function useThreads(
|
||||
});
|
||||
}
|
||||
|
||||
export function useThreadRuns(threadId?: string) {
|
||||
const apiClient = getAPIClient();
|
||||
return useQuery<Run[]>({
|
||||
queryKey: ["thread", threadId],
|
||||
queryFn: async () => {
|
||||
if (!threadId) {
|
||||
return [];
|
||||
}
|
||||
const response = await apiClient.runs.list(threadId);
|
||||
return response;
|
||||
},
|
||||
refetchOnWindowFocus: false,
|
||||
});
|
||||
}
|
||||
|
||||
export function useRunDetail(threadId: string, runId: string) {
|
||||
const apiClient = getAPIClient();
|
||||
return useQuery<Run>({
|
||||
queryKey: ["thread", threadId, "run", runId],
|
||||
queryFn: async () => {
|
||||
const response = await apiClient.runs.get(threadId, runId);
|
||||
return response;
|
||||
},
|
||||
refetchOnWindowFocus: false,
|
||||
});
|
||||
}
|
||||
|
||||
export function useDeleteThread() {
|
||||
const queryClient = useQueryClient();
|
||||
const apiClient = getAPIClient();
|
||||
@ -609,7 +696,7 @@ export function useDeleteThread() {
|
||||
mutationFn: async ({ threadId }: { threadId: string }) => {
|
||||
await apiClient.threads.delete(threadId);
|
||||
|
||||
const response = await fetchWithAuth(
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
@ -682,65 +769,3 @@ export function useRenameThread() {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/** Per-message enrichment data attached by the backend ``/history`` helper. */
|
||||
export interface MessageEnrichment {
|
||||
run_id: string;
|
||||
/** ``undefined`` = not feedback-eligible; ``null`` = eligible but unrated. */
|
||||
feedback?: FeedbackData | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch ``/history`` once and index feedback + run_id by message id.
|
||||
*
|
||||
* Replaces the old ``useThreadFeedback`` hook which keyed by AI-message
|
||||
* ordinal position — an inherently fragile mapping that broke whenever
|
||||
* ``ai_tool_call`` messages were interleaved with ``ai_message`` messages.
|
||||
* Keying by ``message.id`` is stable regardless of run count, tool-call
|
||||
* chains, or summarization.
|
||||
*
|
||||
* The ``/history`` response is refreshed on every stream completion via
|
||||
* ``invalidateQueries(["thread-message-enrichment"])`` in ``onFinish``.
|
||||
*/
|
||||
export function useThreadMessageEnrichment(
|
||||
threadId: string | null | undefined,
|
||||
) {
|
||||
return useQuery({
|
||||
queryKey: ["thread-message-enrichment", threadId],
|
||||
queryFn: async (): Promise<Map<string, MessageEnrichment>> => {
|
||||
const empty = new Map<string, MessageEnrichment>();
|
||||
if (!threadId) return empty;
|
||||
const res = await fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/history`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ limit: 1 }),
|
||||
},
|
||||
);
|
||||
if (!res.ok) return empty;
|
||||
const entries = (await res.json()) as Array<{
|
||||
values?: {
|
||||
messages?: Array<{
|
||||
id?: string;
|
||||
run_id?: string;
|
||||
feedback?: FeedbackData | null;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
const messages = entries[0]?.values?.messages ?? [];
|
||||
const map = new Map<string, MessageEnrichment>();
|
||||
for (const m of messages) {
|
||||
if (!m.id || !m.run_id) continue;
|
||||
const entry: MessageEnrichment = { run_id: m.run_id };
|
||||
// Preserve presence: "feedback" key absent → ineligible; present with
|
||||
// null → eligible but unrated; present with object → rated.
|
||||
if ("feedback" in m) entry.feedback = m.feedback;
|
||||
map.set(m.id, entry);
|
||||
}
|
||||
return map;
|
||||
},
|
||||
enabled: !!threadId,
|
||||
staleTime: 30_000,
|
||||
});
|
||||
}
|
||||
|
||||
@ -22,3 +22,12 @@ export interface AgentThreadContext extends Record<string, unknown> {
|
||||
export interface AgentThread extends Thread<AgentThreadState> {
|
||||
context?: AgentThreadContext;
|
||||
}
|
||||
|
||||
export interface RunMessage {
|
||||
run_id: string;
|
||||
content: Message;
|
||||
metadata: {
|
||||
caller: string;
|
||||
};
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
* API functions for file uploads
|
||||
*/
|
||||
|
||||
import { fetchWithAuth } from "../api/fetcher";
|
||||
import { fetch } from "../api/fetcher";
|
||||
import { getBackendBaseURL } from "../config";
|
||||
|
||||
export interface UploadedFileInfo {
|
||||
@ -51,7 +51,7 @@ export async function uploadFiles(
|
||||
formData.append("files", file);
|
||||
});
|
||||
|
||||
const response = await fetchWithAuth(
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${threadId}/uploads`,
|
||||
{
|
||||
method: "POST",
|
||||
@ -92,7 +92,7 @@ export async function deleteUploadedFile(
|
||||
threadId: string,
|
||||
filename: string,
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
const response = await fetchWithAuth(
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${threadId}/uploads/${filename}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user