fix: dedupe token usage aggregation by message id (#2770)

This commit is contained in:
YuJitang 2026-05-08 09:54:20 +08:00 committed by GitHub
parent 6c220a9aef
commit 530bda7107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 6 deletions

View File

@ -28,7 +28,12 @@ export function getUsageMetadata(message: Message): TokenUsage | null {
}
/**
* Accumulate token usage across all AI messages in a thread.
* Accumulate token usage across AI messages.
*
* UI rendering may place the same AI message in more than one group, such as
* when a message contains both reasoning and final answer content. Token usage
* is attached to the AI message itself, so a message id should only contribute
* once to any aggregate.
*/
export function accumulateUsage(messages: Message[]): TokenUsage | null {
const cumulative: TokenUsage = {
@ -37,14 +42,25 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
totalTokens: 0,
};
let hasUsage = false;
const countedMessageIds = new Set<string>();
for (const message of messages) {
const usage = getUsageMetadata(message);
if (usage) {
hasUsage = true;
cumulative.inputTokens += usage.inputTokens;
cumulative.outputTokens += usage.outputTokens;
cumulative.totalTokens += usage.totalTokens;
if (!usage) {
continue;
}
if (message.id) {
if (countedMessageIds.has(message.id)) {
continue;
}
countedMessageIds.add(message.id);
}
hasUsage = true;
cumulative.inputTokens += usage.inputTokens;
cumulative.outputTokens += usage.outputTokens;
cumulative.totalTokens += usage.totalTokens;
}
return hasUsage ? cumulative : null;
}

View File

@ -0,0 +1,81 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import { accumulateUsage } from "@/core/messages/usage";
import {
getAssistantTurnUsageMessages,
getMessageGroups,
} from "@/core/messages/utils";
test("accumulates each AI message usage only once by message id", () => {
  // The same AI message object rendered in two UI groups must contribute
  // its token usage exactly once.
  const duplicated = {
    id: "ai-1",
    type: "ai",
    content: "Answer",
    usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
  } as Message;

  const total = accumulateUsage([duplicated, duplicated]);

  expect(total).toEqual({
    inputTokens: 10,
    outputTokens: 5,
    totalTokens: 15,
  });
});
test("counts later usage-bearing snapshots for the same AI message id", () => {
  // A streaming snapshot without usage must not "use up" the id, so the
  // completed snapshot that carries usage_metadata still gets counted.
  const snapshots = [
    {
      id: "ai-1",
      type: "ai",
      content: "Streaming...",
    },
    {
      id: "ai-1",
      type: "ai",
      content: "Complete answer",
      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
    },
  ] as Message[];

  expect(accumulateUsage(snapshots)).toEqual({
    inputTokens: 10,
    outputTokens: 5,
    totalTokens: 15,
  });
});
test("keeps header and per-turn aggregation consistent for duplicated UI groups", () => {
  const conversation = [
    {
      id: "human-1",
      type: "human",
      content: "Explain this",
    },
    {
      id: "ai-1",
      type: "ai",
      content: "<think>checking context</think>Final answer",
      usage_metadata: { input_tokens: 20, output_tokens: 7, total_tokens: 27 },
    },
  ] as Message[];

  const groups = getMessageGroups(conversation);
  const usageByGroup = getAssistantTurnUsageMessages(groups);
  const lastTurnMessages = usageByGroup.at(-1);

  // The reasoning + answer message is split into two assistant groups...
  expect(groups.map((group) => group.type)).toEqual([
    "human",
    "assistant:processing",
    "assistant",
  ]);
  // ...so the same message id appears twice in the turn's usage messages.
  expect(lastTurnMessages?.map((message) => message.id)).toEqual([
    "ai-1",
    "ai-1",
  ]);
  // Header-level and per-turn aggregation must agree despite the duplication.
  expect(accumulateUsage(conversation)).toEqual(
    accumulateUsage(lastTurnMessages!),
  );
  expect(accumulateUsage(lastTurnMessages!)).toEqual({
    inputTokens: 20,
    outputTokens: 7,
    totalTokens: 27,
  });
});