mirror of https://github.com/bytedance/deer-flow.git (synced 2026-05-12 19:53:40 +00:00)

fix: use backend thread token usage for header total (#2800)

* fix: use backend thread token usage for header total
* Refactor thread token usage fetch

parent 881ff71252, commit 417416087b
@@ -68,6 +68,27 @@ class RunResponse(BaseModel):
     updated_at: str = ""
 
 
+class ThreadTokenUsageModelBreakdown(BaseModel):
+    tokens: int = 0
+    runs: int = 0
+
+
+class ThreadTokenUsageCallerBreakdown(BaseModel):
+    lead_agent: int = 0
+    subagent: int = 0
+    middleware: int = 0
+
+
+class ThreadTokenUsageResponse(BaseModel):
+    thread_id: str
+    total_tokens: int = 0
+    total_input_tokens: int = 0
+    total_output_tokens: int = 0
+    total_runs: int = 0
+    by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
+    by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
+
+
 # ---------------------------------------------------------------------------
 # Helpers
 # ---------------------------------------------------------------------------

@@ -368,10 +389,10 @@ async def list_run_events(
     return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
 
 
-@router.get("/{thread_id}/token-usage")
+@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
 @require_permission("threads", "read", owner_check=True)
-async def thread_token_usage(thread_id: str, request: Request) -> dict:
+async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
    """Thread-level token usage aggregation."""
     run_store = get_run_store(request)
     agg = await run_store.aggregate_tokens_by_thread(thread_id)
-    return {"thread_id": thread_id, **agg}
+    return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
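With response_model declared, the gateway now validates and serializes this payload against ThreadTokenUsageResponse instead of returning a raw dict. For reference only, a payload satisfying that model could look like the following TypeScript object (every value here is invented; the shape follows the Pydantic models above):

// Hypothetical GET /api/threads/{thread_id}/token-usage payload,
// shaped by ThreadTokenUsageResponse. All numbers are made up.
const exampleUsage = {
  thread_id: "thread-1",
  total_tokens: 150,
  total_input_tokens: 90,
  total_output_tokens: 60,
  total_runs: 2,
  by_model: { unknown: { tokens: 150, runs: 2 } },
  by_caller: { lead_agent: 120, subagent: 25, middleware: 5 },
};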
@@ -166,6 +166,61 @@ class TestRunRepository:
         assert row["total_tokens"] == 100
         await _cleanup()
 
+    @pytest.mark.anyio
+    async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
+        repo = await _make_repo(tmp_path)
+        await repo.put("success-run", thread_id="t1", status="running")
+        await repo.update_run_completion(
+            "success-run",
+            status="success",
+            total_input_tokens=70,
+            total_output_tokens=30,
+            total_tokens=100,
+            lead_agent_tokens=80,
+            subagent_tokens=15,
+            middleware_tokens=5,
+        )
+        await repo.put("error-run", thread_id="t1", status="running")
+        await repo.update_run_completion(
+            "error-run",
+            status="error",
+            total_input_tokens=20,
+            total_output_tokens=30,
+            total_tokens=50,
+            lead_agent_tokens=40,
+            subagent_tokens=10,
+        )
+        await repo.put("running-run", thread_id="t1", status="running")
+        await repo.update_run_completion(
+            "running-run",
+            status="running",
+            total_input_tokens=900,
+            total_output_tokens=99,
+            total_tokens=999,
+            lead_agent_tokens=999,
+        )
+        await repo.put("other-thread-run", thread_id="t2", status="running")
+        await repo.update_run_completion(
+            "other-thread-run",
+            status="success",
+            total_tokens=888,
+            lead_agent_tokens=888,
+        )
+
+        agg = await repo.aggregate_tokens_by_thread("t1")
+
+        assert agg["total_tokens"] == 150
+        assert agg["total_input_tokens"] == 90
+        assert agg["total_output_tokens"] == 60
+        assert agg["total_runs"] == 2
+        assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}}
+        assert agg["by_caller"] == {
+            "lead_agent": 120,
+            "subagent": 25,
+            "middleware": 5,
+        }
+        await _cleanup()
+
     @pytest.mark.anyio
     async def test_list_by_thread_ordered_desc(self, tmp_path):
         """list_by_thread returns newest first."""
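The aggregation itself lives in the run repository and is not part of this diff; the assertions above only pin down its semantics: runs in a terminal state (success or error) in the target thread are counted, still-running runs and other threads are excluded, and runs without a recorded model land in an "unknown" bucket. A minimal TypeScript sketch of those semantics, with types and field names assumed from the test rather than taken from the actual implementation:

interface RunRow {
  threadId: string;
  status: "running" | "success" | "error";
  totalTokens: number;
  model?: string;
}

// Simplified: real output also carries input/output splits and by_caller.
function aggregateTokensByThread(rows: RunRow[], threadId: string) {
  // Keep only terminal runs in this thread; "running" rows are excluded.
  const done = rows.filter(
    (r) => r.threadId === threadId && r.status !== "running",
  );
  const byModel: Record<string, { tokens: number; runs: number }> = {};
  for (const r of done) {
    // Runs without a model fall into the "unknown" bucket.
    const bucket = (byModel[r.model ?? "unknown"] ??= { tokens: 0, runs: 0 });
    bucket.tokens += r.totalTokens;
    bucket.runs += 1;
  }
  return {
    total_tokens: done.reduce((sum, r) => sum + r.totalTokens, 0),
    total_runs: done.length,
    by_model: byModel,
  };
}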
backend/tests/test_thread_token_usage.py (new file, +55)
@@ -0,0 +1,55 @@
+"""Tests for thread-level token usage aggregation API."""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock
+
+from _router_auth_helpers import make_authed_test_app
+from fastapi.testclient import TestClient
+
+from app.gateway.routers import thread_runs
+
+
+def _make_app(run_store: MagicMock):
+    app = make_authed_test_app()
+    app.include_router(thread_runs.router)
+    app.state.run_store = run_store
+    return app
+
+
+def test_thread_token_usage_returns_stable_shape():
+    run_store = MagicMock()
+    run_store.aggregate_tokens_by_thread = AsyncMock(
+        return_value={
+            "total_tokens": 150,
+            "total_input_tokens": 90,
+            "total_output_tokens": 60,
+            "total_runs": 2,
+            "by_model": {"unknown": {"tokens": 150, "runs": 2}},
+            "by_caller": {
+                "lead_agent": 120,
+                "subagent": 25,
+                "middleware": 5,
+            },
+        },
+    )
+    app = _make_app(run_store)
+
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-1/token-usage")
+
+    assert response.status_code == 200
+    assert response.json() == {
+        "thread_id": "thread-1",
+        "total_tokens": 150,
+        "total_input_tokens": 90,
+        "total_output_tokens": 60,
+        "total_runs": 2,
+        "by_model": {"unknown": {"tokens": 150, "runs": 2}},
+        "by_caller": {
+            "lead_agent": 120,
+            "subagent": 25,
+            "middleware": 5,
+        },
+    }
+    run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
@@ -26,7 +26,8 @@ import { useI18n } from "@/core/i18n/hooks";
 import { useModels } from "@/core/models/hooks";
 import { useNotification } from "@/core/notification/hooks";
 import { useLocalSettings, useThreadSettings } from "@/core/settings";
-import { useThreadStream } from "@/core/threads/hooks";
+import { useThreadStream, useThreadTokenUsage } from "@/core/threads/hooks";
+import { threadTokenUsageToTokenUsage } from "@/core/threads/token-usage";
 import { textOfMessage } from "@/core/threads/utils";
 import { env } from "@/env";
 import { cn } from "@/lib/utils";

@@ -42,15 +43,21 @@ export default function AgentChatPage() {
 
   const { agent } = useAgent(agent_name);
 
-  const { threadId, setThreadId, isNewThread, setIsNewThread } =
+  const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
     useThreadChat();
   const [settings, setSettings] = useThreadSettings(threadId);
   const [localSettings, setLocalSettings] = useLocalSettings();
   const { tokenUsageEnabled } = useModels();
+  const threadTokenUsage = useThreadTokenUsage(
+    isNewThread || isMock ? undefined : threadId,
+    { enabled: tokenUsageEnabled && !isMock },
+  );
+  const backendTokenUsage = threadTokenUsageToTokenUsage(threadTokenUsage.data);
 
   const { showNotification } = useNotification();
   const {
     thread,
+    pendingUsageMessages,
     sendMessage,
     isHistoryLoading,
     hasMoreHistory,

@@ -58,6 +65,7 @@ export default function AgentChatPage() {
   } = useThreadStream({
     threadId: isNewThread ? undefined : threadId,
     context: { ...settings.context, agent_name: agent_name },
+    isMock,
     onStart: (createdThreadId) => {
       setThreadId(createdThreadId);
       setIsNewThread(false);

@@ -141,8 +149,11 @@ export default function AgentChatPage() {
           </Button>
         </Tooltip>
         <TokenUsageIndicator
+          threadId={isNewThread ? undefined : threadId}
+          backendUsage={backendTokenUsage}
           enabled={tokenUsageEnabled}
           messages={thread.messages}
+          pendingMessages={pendingUsageMessages}
           preferences={localSettings.tokenUsage}
           onPreferencesChange={(preferences) =>
             setLocalSettings("tokenUsage", preferences)
@@ -25,7 +25,8 @@ import { useI18n } from "@/core/i18n/hooks";
 import { useModels } from "@/core/models/hooks";
 import { useNotification } from "@/core/notification/hooks";
 import { useLocalSettings, useThreadSettings } from "@/core/settings";
-import { useThreadStream } from "@/core/threads/hooks";
+import { useThreadStream, useThreadTokenUsage } from "@/core/threads/hooks";
+import { threadTokenUsageToTokenUsage } from "@/core/threads/token-usage";
 import { textOfMessage } from "@/core/threads/utils";
 import { env } from "@/env";
 import { cn } from "@/lib/utils";

@@ -44,6 +45,11 @@ export default function ChatPage() {
   const [settings, setSettings] = useThreadSettings(threadId);
   const [localSettings, setLocalSettings] = useLocalSettings();
   const { tokenUsageEnabled } = useModels();
+  const threadTokenUsage = useThreadTokenUsage(
+    isNewThread || isMock ? undefined : threadId,
+    { enabled: tokenUsageEnabled && !isMock },
+  );
+  const backendTokenUsage = threadTokenUsageToTokenUsage(threadTokenUsage.data);
   const mountedRef = useRef(false);
   useSpecificChatMode();
 
@@ -63,6 +69,7 @@ export default function ChatPage() {
 
   const {
     thread,
+    pendingUsageMessages,
     sendMessage,
     isUploading,
     isHistoryLoading,

@@ -137,8 +144,11 @@ export default function ChatPage() {
         </div>
         <div className="flex items-center gap-2">
           <TokenUsageIndicator
+            threadId={isNewThread ? undefined : threadId}
+            backendUsage={backendTokenUsage}
             enabled={tokenUsageEnabled}
             messages={thread.messages}
+            pendingMessages={pendingUsageMessages}
             preferences={localSettings.tokenUsage}
             onPreferencesChange={(preferences) =>
               setLocalSettings("tokenUsage", preferences)
@@ -15,7 +15,11 @@ import {
   DropdownMenuTrigger,
 } from "@/components/ui/dropdown-menu";
 import { useI18n } from "@/core/i18n/hooks";
-import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
+import {
+  formatTokenCount,
+  selectHeaderTokenUsage,
+  type TokenUsage,
+} from "@/core/messages/usage";
 import {
   getTokenUsageViewPreset,
   tokenUsagePreferencesFromPreset,

@@ -25,7 +29,10 @@ import {
 import { cn } from "@/lib/utils";
 
 interface TokenUsageIndicatorProps {
+  threadId?: string;
   messages: Message[];
+  pendingMessages?: Message[];
+  backendUsage?: TokenUsage | null;
   enabled?: boolean;
   preferences: TokenUsagePreferences;
   onPreferencesChange: (preferences: TokenUsagePreferences) => void;

@@ -33,7 +40,10 @@ interface TokenUsageIndicatorProps {
 }
 
 export function TokenUsageIndicator({
+  threadId,
   messages,
+  pendingMessages,
+  backendUsage,
   enabled = false,
   preferences,
   onPreferencesChange,

@@ -41,7 +51,15 @@ export function TokenUsageIndicator({
 }: TokenUsageIndicatorProps) {
   const { t } = useI18n();
 
-  const usage = useMemo(() => accumulateUsage(messages), [messages]);
+  const usage = useMemo(
+    () =>
+      selectHeaderTokenUsage({
+        backendUsage: threadId ? backendUsage : null,
+        messages,
+        pendingMessages,
+      }),
+    [backendUsage, messages, pendingMessages, threadId],
+  );
   const preset = getTokenUsageViewPreset(preferences);
 
   if (!enabled) {
@@ -310,7 +310,7 @@ export const enUS: Translations = {
       unavailable:
         "No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.",
       unavailableShort: "No usage returned",
-      note: "Shown from provider-returned usage_metadata. Totals are best-effort conversation totals and may differ from provider billing pages.",
+      note: "Header totals use persisted thread usage when available. Per-turn and debug usage come from visible messages. Totals may differ from provider billing pages.",
       presets: {
         off: "Off",
         summary: "Summary",
@@ -296,7 +296,7 @@ export const zhCN: Translations = {
       unavailable:
         "暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。",
       unavailableShort: "未返回用量",
-      note: "基于供应商返回的 usage_metadata 展示。当前总量是 best-effort 的会话参考值,可能与平台账单页不完全一致。",
+      note: "顶部总量优先使用后端持久化的线程用量。每轮和调试用量来自当前可见消息,可能与平台账单页不完全一致。",
       presets: {
         off: "关闭",
         summary: "总览",
@@ -65,6 +65,40 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
   return hasUsage ? cumulative : null;
 }
 
+function hasNonZeroUsage(
+  usage: TokenUsage | null | undefined,
+): usage is TokenUsage {
+  return (
+    usage !== null &&
+    usage !== undefined &&
+    (usage.inputTokens > 0 || usage.outputTokens > 0 || usage.totalTokens > 0)
+  );
+}
+
+function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
+  return {
+    inputTokens: base.inputTokens + delta.inputTokens,
+    outputTokens: base.outputTokens + delta.outputTokens,
+    totalTokens: base.totalTokens + delta.totalTokens,
+  };
+}
+
+export function selectHeaderTokenUsage({
+  backendUsage,
+  messages,
+  pendingMessages = [],
+}: {
+  backendUsage?: TokenUsage | null;
+  messages: Message[];
+  pendingMessages?: Message[];
+}): TokenUsage | null {
+  if (hasNonZeroUsage(backendUsage)) {
+    const pendingUsage = accumulateUsage(pendingMessages);
+    return pendingUsage ? addUsage(backendUsage, pendingUsage) : backendUsage;
+  }
+  return accumulateUsage(messages);
+}
+
 /**
  * Format a token count for display: 1234 -> "1,234", 12345 -> "12.3K"
  */
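Design note on the branch above: the persisted backend total already covers every completed run, so adding the full visible message list on top of it would double count. Only pendingMessages, the in-flight portion of the current run, are added, and visible messages serve as the fallback when no backend total exists. A quick worked example (all values invented, message data modeled after the tests below):

import type { Message } from "@langchain/langgraph-sdk";
import { selectHeaderTokenUsage } from "@/core/messages/usage";

// Invented scenario: 150 tokens persisted on the thread, plus one
// streaming message worth 10 tokens that is not yet persisted.
const streaming = [
  {
    id: "ai-pending",
    type: "ai",
    content: "streaming...",
    usage_metadata: { input_tokens: 4, output_tokens: 6, total_tokens: 10 },
  },
] as Message[];

const header = selectHeaderTokenUsage({
  backendUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
  messages: streaming, // ignored while backendUsage is non-zero
  pendingMessages: streaming, // only the in-flight delta is added
});
// header = { inputTokens: 104, outputTokens: 56, totalTokens: 160 }:
// persisted 150 plus in-flight 10, with nothing counted twice.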
frontend/src/core/threads/api.ts (new file, +24)
@@ -0,0 +1,24 @@
+import { fetch as fetchWithAuth } from "@/core/api/fetcher";
+import { getBackendBaseURL } from "@/core/config";
+
+import type { ThreadTokenUsageResponse } from "./types";
+
+export async function fetchThreadTokenUsage(
+  threadId: string,
+): Promise<ThreadTokenUsageResponse | null> {
+  const response = await fetchWithAuth(
+    `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/token-usage`,
+    {
+      method: "GET",
+    },
+  );
+
+  if (!response.ok) {
+    if (response.status === 403 || response.status === 404) {
+      return null;
+    }
+    throw new Error("Failed to load thread token usage.");
+  }
+
+  return (await response.json()) as ThreadTokenUsageResponse;
+}
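A minimal consumer sketch of the contract above: 403 and 404 resolve to null rather than throwing, so callers can treat forbidden or missing usage as simply absent (the thread id is a placeholder):

import { fetchThreadTokenUsage } from "@/core/threads/api";

const usage = await fetchThreadTokenUsage("thread-1");
if (usage) {
  console.log(`${usage.total_tokens} tokens across ${usage.total_runs} runs`);
} else {
  // New, mock, or inaccessible thread: no usage to show.
  console.log("No thread token usage available.");
}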
@@ -17,7 +17,14 @@ import { useUpdateSubtask } from "../tasks/context";
 import type { UploadedFileInfo } from "../uploads";
 import { promptInputFilePartToFile, uploadFiles } from "../uploads";
 
-import type { AgentThread, AgentThreadState, RunMessage } from "./types";
+import { fetchThreadTokenUsage } from "./api";
+import { threadTokenUsageQueryKey } from "./token-usage";
+import type {
+  AgentThread,
+  AgentThreadState,
+  RunMessage,
+  ThreadTokenUsageResponse,
+} from "./types";
 
 export type ToolEndEvent = {
   name: string;

@@ -75,6 +82,23 @@ function mergeMessages(
   ];
 }
 
+function messageIdentity(message: Message): string | undefined {
+  if ("tool_call_id" in message) {
+    return message.tool_call_id;
+  }
+  return message.id;
+}
+
+function getMessagesAfterBaseline(
+  messages: Message[],
+  baselineMessageIds: ReadonlySet<string>,
+): Message[] {
+  return messages.filter((message) => {
+    const id = messageIdentity(message);
+    return !id || !baselineMessageIds.has(id);
+  });
+}
+
 function getStreamErrorMessage(error: unknown): string {
   if (typeof error === "string" && error.trim()) {
     return error;

@@ -114,6 +138,7 @@ export function useThreadStream({
   // and to allow access to the current thread id in onUpdateEvent
   const threadIdRef = useRef<string | null>(threadId ?? null);
   const startedRef = useRef(false);
+  const pendingUsageBaselineMessageIdsRef = useRef<Set<string>>(new Set());
   const listeners = useRef({
     onSend,
     onStart,

@@ -271,29 +296,42 @@
     onError(error) {
       setOptimisticMessages([]);
       toast.error(getStreamErrorMessage(error));
+      pendingUsageBaselineMessageIdsRef.current = new Set();
+      if (threadIdRef.current && !isMock) {
+        void queryClient.invalidateQueries({
+          queryKey: threadTokenUsageQueryKey(threadIdRef.current),
+        });
+      }
     },
     onFinish(state) {
       listeners.current.onFinish?.(state.values);
+      pendingUsageBaselineMessageIdsRef.current = new Set();
       void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
+      if (threadIdRef.current && !isMock) {
+        void queryClient.invalidateQueries({
+          queryKey: threadTokenUsageQueryKey(threadIdRef.current),
+        });
+      }
     },
   });
 
   // Optimistic messages shown before the server stream responds
   const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
   const [isUploading, setIsUploading] = useState(false);
+  const humanMessageCount = thread.messages.filter(
+    (m) => m.type === "human",
+  ).length;
+  const latestMessageCountsRef = useRef({ humanMessageCount });
   const sendInFlightRef = useRef(false);
   const messagesRef = useRef<Message[]>([]);
   const summarizedRef = useRef<Set<string>>(null);
   // Track message count before sending so we know when server has responded
   const prevMsgCountRef = useRef(thread.messages.length);
   // Track human message count before sending to prevent clearing optimistic
   // messages before the server's human message arrives (e.g. when AI messages
   // from "messages-tuple" events arrive before the input human message from
   // "values" events).
-  const prevHumanMsgCountRef = useRef(
-    thread.messages.filter((m) => m.type === "human").length,
-  );
+  const prevHumanMsgCountRef = useRef(humanMessageCount);
 
+  latestMessageCountsRef.current = { humanMessageCount };
   summarizedRef.current ??= new Set<string>();
 
   // Reset thread-local pending UI state when switching between threads so

@@ -301,31 +339,43 @@
   useEffect(() => {
     startedRef.current = false;
     sendInFlightRef.current = false;
     prevMsgCountRef.current = thread.messages.length;
-    prevHumanMsgCountRef.current = thread.messages.filter(
-      (m) => m.type === "human",
-    ).length;
+    pendingUsageBaselineMessageIdsRef.current = new Set();
+    prevHumanMsgCountRef.current =
+      latestMessageCountsRef.current.humanMessageCount;
   }, [threadId]);
 
+  // When streaming starts without a baseline (e.g. reconnection, run started
+  // from another client, or page reload mid-stream), snapshot the current
+  // messages so only *new* messages are treated as "pending" for token usage.
+  useEffect(() => {
+    if (
+      thread.isLoading &&
+      pendingUsageBaselineMessageIdsRef.current.size === 0
+    ) {
+      pendingUsageBaselineMessageIdsRef.current = new Set(
+        thread.messages
+          .map(messageIdentity)
+          .filter((id): id is string => Boolean(id)),
+      );
+    }
+  }, [thread.isLoading, thread.messages]);
+
   // Clear optimistic when server messages arrive.
   // For messages with a human optimistic message, wait until the server's
   // human message has arrived to avoid clearing before the input message
   // appears in the stream (the input message may arrive via "values" events
   // after individual "messages-tuple" events for AI messages).
+  const optimisticMessageCount = optimisticMessages.length;
+  const hasHumanOptimistic = optimisticMessages.some((m) => m.type === "human");
   useEffect(() => {
-    if (optimisticMessages.length === 0) return;
+    if (optimisticMessageCount === 0) return;
 
-    const hasHumanOptimistic = optimisticMessages.some(
-      (m) => m.type === "human",
-    );
-    const newHumanMsgArrived =
-      thread.messages.filter((m) => m.type === "human").length >
-      prevHumanMsgCountRef.current;
+    const newHumanMsgArrived = humanMessageCount > prevHumanMsgCountRef.current;
 
     if (!hasHumanOptimistic || newHumanMsgArrived) {
       setOptimisticMessages([]);
     }
-  }, [thread.messages.length, optimisticMessages.length]);
+  }, [hasHumanOptimistic, humanMessageCount, optimisticMessageCount]);

@@ -341,11 +391,14 @@
 
       const text = message.text.trim();
 
-      // Capture current count before showing optimistic messages
-      prevMsgCountRef.current = thread.messages.length;
-      prevHumanMsgCountRef.current = thread.messages.filter(
-        (m) => m.type === "human",
-      ).length;
+      // Capture the current human message count before showing optimistic
+      // messages so we can wait for the server's copy of the user input.
+      prevHumanMsgCountRef.current = humanMessageCount;
+      pendingUsageBaselineMessageIdsRef.current = new Set(
+        thread.messages
+          .map(messageIdentity)
+          .filter((id): id is string => Boolean(id)),
+      );
 
       // Build optimistic files list with uploading status
       const optimisticFiles: FileInMessage[] = (message.files ?? []).map(

@@ -517,7 +570,7 @@
         sendInFlightRef.current = false;
       }
     },
-    [thread, t.uploads.uploadingFiles, context, queryClient],
+    [thread, t.uploads.uploadingFiles, context, queryClient, humanMessageCount],
   );
 
   // Cache the latest thread messages in a ref to compare against incoming history messages for deduplication,

@@ -531,6 +584,12 @@
     thread.messages,
     optimisticMessages,
   );
+  const pendingUsageMessages = thread.isLoading
+    ? getMessagesAfterBaseline(
+        thread.messages,
+        pendingUsageBaselineMessageIdsRef.current,
+      )
+    : [];
 
   // Merge history, live stream, and optimistic messages for display
   // History messages may overlap with thread.messages; thread.messages take precedence

@@ -541,6 +600,7 @@
 
   return {
     thread: mergedThread,
+    pendingUsageMessages,
     sendMessage,
     isUploading,
     isHistoryLoading,

@@ -701,6 +761,24 @@ export function useThreadRuns(threadId?: string) {
   });
 }
 
+export function useThreadTokenUsage(
+  threadId?: string | null,
+  { enabled = true }: { enabled?: boolean } = {},
+) {
+  return useQuery<ThreadTokenUsageResponse | null>({
+    queryKey: threadTokenUsageQueryKey(threadId),
+    queryFn: async () => {
+      if (!threadId) {
+        return null;
+      }
+      return fetchThreadTokenUsage(threadId);
+    },
+    enabled: enabled && Boolean(threadId),
+    retry: false,
+    refetchOnWindowFocus: false,
+  });
+}
+
 export function useRunDetail(threadId: string, runId: string) {
   const apiClient = getAPIClient();
   return useQuery<Run>({
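How the pending-usage baseline above behaves, in a self-contained sketch (all data invented): message identities are snapshotted when a send starts, and while the stream is loading only messages outside that snapshot count as pending, so their usage tops up the persisted backend total without counting completed runs twice.

import type { Message } from "@langchain/langgraph-sdk";

// Snapshot taken at send time (or on reconnection mid-stream).
const baseline = new Set(["human-1", "ai-1"]);

const current = [
  { id: "human-1", type: "human", content: "hi" },
  { id: "ai-1", type: "ai", content: "hello" },
  {
    id: "ai-2",
    type: "ai",
    content: "streaming...",
    usage_metadata: { input_tokens: 4, output_tokens: 6, total_tokens: 10 },
  },
] as Message[];

// Mirrors getMessagesAfterBaseline: only "ai-2" is pending, so only its
// 10 tokens are layered on top of the persisted thread total.
const pending = current.filter((m) => !m.id || !baseline.has(m.id));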
frontend/src/core/threads/token-usage.ts (new file, +20)
@@ -0,0 +1,20 @@
+import type { TokenUsage } from "@/core/messages/usage";
+
+import type { ThreadTokenUsageResponse } from "./types";
+
+export function threadTokenUsageQueryKey(threadId?: string | null) {
+  return ["thread-token-usage", threadId] as const;
+}
+
+export function threadTokenUsageToTokenUsage(
+  usage: ThreadTokenUsageResponse | null | undefined,
+): TokenUsage | null {
+  if (!usage) {
+    return null;
+  }
+  return {
+    inputTokens: usage.total_input_tokens ?? 0,
+    outputTokens: usage.total_output_tokens ?? 0,
+    totalTokens: usage.total_tokens ?? 0,
+  };
+}
@@ -31,3 +31,17 @@ export interface RunMessage {
   };
   created_at: string;
 }
+
+export interface ThreadTokenUsageResponse {
+  thread_id: string;
+  total_tokens: number;
+  total_input_tokens: number;
+  total_output_tokens: number;
+  total_runs: number;
+  by_model: Record<string, { tokens: number; runs: number }>;
+  by_caller: {
+    lead_agent: number;
+    subagent: number;
+    middleware: number;
+  };
+}
@@ -1,7 +1,7 @@
 import type { Message } from "@langchain/langgraph-sdk";
 import { expect, test } from "vitest";
 
-import { accumulateUsage } from "@/core/messages/usage";
+import { accumulateUsage, selectHeaderTokenUsage } from "@/core/messages/usage";
 import {
   getAssistantTurnUsageMessages,
   getMessageGroups,

@@ -79,3 +79,86 @@ test("keeps header and per-turn aggregation consistent for duplicated UI groups"
     totalTokens: 27,
   });
 });
+
+test("prefers backend thread usage for header totals", () => {
+  const messages = [
+    {
+      id: "ai-visible",
+      type: "ai",
+      content: "Visible answer",
+      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
+    },
+  ] as Message[];
+
+  expect(
+    selectHeaderTokenUsage({
+      backendUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
+      messages,
+    }),
+  ).toEqual({
+    inputTokens: 100,
+    outputTokens: 50,
+    totalTokens: 150,
+  });
+});
+
+test("adds current in-flight message usage to backend header totals", () => {
+  const completedMessages = [
+    {
+      id: "ai-completed",
+      type: "ai",
+      content: "Completed answer",
+      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
+    },
+    {
+      id: "ai-pending",
+      type: "ai",
+      content: "Streaming answer",
+      usage_metadata: { input_tokens: 4, output_tokens: 6, total_tokens: 10 },
+    },
+  ] as Message[];
+
+  expect(
+    selectHeaderTokenUsage({
+      backendUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
+      messages: completedMessages,
+      pendingMessages: [completedMessages[1]!],
+    }),
+  ).toEqual({
+    inputTokens: 104,
+    outputTokens: 56,
+    totalTokens: 160,
+  });
+});
+
+test("falls back to visible messages when backend usage is unavailable or zero", () => {
+  const messages = [
+    {
+      id: "ai-visible",
+      type: "ai",
+      content: "Visible answer",
+      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
+    },
+  ] as Message[];
+
+  expect(
+    selectHeaderTokenUsage({
+      backendUsage: null,
+      messages,
+    }),
+  ).toEqual({
+    inputTokens: 10,
+    outputTokens: 5,
+    totalTokens: 15,
+  });
+  expect(
+    selectHeaderTokenUsage({
+      backendUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 },
+      messages,
+    }),
+  ).toEqual({
+    inputTokens: 10,
+    outputTokens: 5,
+    totalTokens: 15,
+  });
+});
frontend/tests/unit/core/threads/api.test.ts (new file, +51)
@@ -0,0 +1,51 @@
+import { beforeEach, expect, test, vi } from "vitest";
+
+const fetchWithAuth = vi.fn();
+
+vi.mock("@/core/api/fetcher", () => ({
+  fetch: fetchWithAuth,
+}));
+
+beforeEach(() => {
+  fetchWithAuth.mockReset();
+});
+
+test("fetchThreadTokenUsage uses shared auth fetch without JSON GET headers", async () => {
+  fetchWithAuth.mockResolvedValue({
+    ok: true,
+    json: async () => ({
+      thread_id: "thread-1",
+      total_input_tokens: 3,
+      total_output_tokens: 4,
+      total_tokens: 7,
+      total_runs: 1,
+      by_model: { unknown: { tokens: 7, runs: 1 } },
+      by_caller: {},
+    }),
+  });
+
+  const { fetchThreadTokenUsage } = await import("@/core/threads/api");
+
+  await expect(fetchThreadTokenUsage("thread-1")).resolves.toMatchObject({
+    thread_id: "thread-1",
+    total_tokens: 7,
+  });
+
+  expect(fetchWithAuth).toHaveBeenCalledWith(
+    expect.stringContaining("/api/threads/thread-1/token-usage"),
+    {
+      method: "GET",
+    },
+  );
+});
+
+test("fetchThreadTokenUsage returns null for unavailable token usage", async () => {
+  fetchWithAuth.mockResolvedValue({
+    ok: false,
+    status: 404,
+  });
+
+  const { fetchThreadTokenUsage } = await import("@/core/threads/api");
+
+  await expect(fetchThreadTokenUsage("thread-1")).resolves.toBeNull();
+});
frontend/tests/unit/core/threads/token-usage.test.ts (new file, +31)
@@ -0,0 +1,31 @@
+import { expect, test } from "vitest";
+
+import { threadTokenUsageToTokenUsage } from "@/core/threads/token-usage";
+import type { ThreadTokenUsageResponse } from "@/core/threads/types";
+
+test("maps backend thread token usage to UI token usage", () => {
+  const response: ThreadTokenUsageResponse = {
+    thread_id: "thread-1",
+    total_input_tokens: 90,
+    total_output_tokens: 60,
+    total_tokens: 150,
+    total_runs: 2,
+    by_model: { unknown: { tokens: 150, runs: 2 } },
+    by_caller: {
+      lead_agent: 120,
+      subagent: 25,
+      middleware: 5,
+    },
+  };
+
+  expect(threadTokenUsageToTokenUsage(response)).toEqual({
+    inputTokens: 90,
+    outputTokens: 60,
+    totalTokens: 150,
+  });
+});
+
+test("returns null when backend thread token usage is unavailable", () => {
+  expect(threadTokenUsageToTokenUsage(null)).toBeNull();
+  expect(threadTokenUsageToTokenUsage(undefined)).toBeNull();
+});