161 changes: 161 additions & 0 deletions src/api/providers/__tests__/bedrock.spec.ts
@@ -1279,4 +1279,165 @@ describe("AwsBedrockHandler", () => {
expect(mockCaptureException).toHaveBeenCalled()
})
})

describe("AI SDK v6 usage field paths", () => {
const systemPrompt = "You are a helpful assistant"
const messages: RooMessage[] = [
{
role: "user",
content: "Hello",
},
]

function setupStream(usage: Record<string, unknown>, providerMetadata: Record<string, unknown> = {}) {
async function* mockFullStream() {
yield { type: "text-delta", text: "reply" }
}

mockStreamText.mockReturnValue({
fullStream: mockFullStream(),
usage: Promise.resolve(usage),
providerMetadata: Promise.resolve(providerMetadata),
})
}

describe("cache tokens", () => {
it("should read cache tokens from v6 top-level cachedInputTokens", async () => {
setupStream({ inputTokens: 100, outputTokens: 50, cachedInputTokens: 30 })

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.cacheReadTokens).toBe(30)
})

it("should read cache tokens from v6 inputTokenDetails.cacheReadTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
inputTokenDetails: { cacheReadTokens: 25 },
})

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.cacheReadTokens).toBe(25)
})

it("should prefer v6 top-level cachedInputTokens over providerMetadata.bedrock", async () => {
setupStream(
{ inputTokens: 100, outputTokens: 50, cachedInputTokens: 30 },
{ bedrock: { usage: { cacheReadInputTokens: 20 } } },
)

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.cacheReadTokens).toBe(30)
})

it("should fall back to providerMetadata.bedrock.usage.cacheReadInputTokens", async () => {
setupStream(
{ inputTokens: 100, outputTokens: 50 },
{ bedrock: { usage: { cacheReadInputTokens: 20 } } },
)

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.cacheReadTokens).toBe(20)
})

it("should read cacheWriteTokens from v6 inputTokenDetails.cacheWriteTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
inputTokenDetails: { cacheWriteTokens: 15 },
})

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.cacheWriteTokens).toBe(15)
})
})

describe("reasoning tokens", () => {
it("should read reasoning tokens from v6 top-level reasoningTokens", async () => {
setupStream({ inputTokens: 100, outputTokens: 50, reasoningTokens: 40 })

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.reasoningTokens).toBe(40)
})

it("should read reasoning tokens from v6 outputTokenDetails.reasoningTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
outputTokenDetails: { reasoningTokens: 35 },
})

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.reasoningTokens).toBe(35)
})

it("should prefer v6 top-level reasoningTokens over outputTokenDetails", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
reasoningTokens: 40,
outputTokenDetails: { reasoningTokens: 15 },
})

const generator = handler.createMessage(systemPrompt, messages)
const chunks: unknown[] = []
for await (const chunk of generator) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c: any) => c.type === "usage") as any
expect(usageChunk).toBeDefined()
expect(usageChunk.reasoningTokens).toBe(40)
})
})
})
})
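Note: the precedence these Bedrock tests pin down is, in order, the v6 top-level usage.cachedInputTokens, then usage.inputTokenDetails, then providerMetadata.bedrock.usage. A minimal TypeScript sketch of that resolution order follows; resolveBedrockUsage and its argument types are hypothetical names for illustration, not the handler's actual API.

    interface V6Usage {
        inputTokens?: number
        outputTokens?: number
        cachedInputTokens?: number
        reasoningTokens?: number
        inputTokenDetails?: { cacheReadTokens?: number; cacheWriteTokens?: number }
        outputTokenDetails?: { reasoningTokens?: number }
    }

    interface BedrockProviderMetadata {
        bedrock?: { usage?: { cacheReadInputTokens?: number } }
    }

    // Hypothetical helper mirroring the precedence asserted in the tests above.
    function resolveBedrockUsage(usage: V6Usage, providerMetadata: BedrockProviderMetadata = {}) {
        return {
            type: "usage" as const,
            inputTokens: usage.inputTokens ?? 0,
            outputTokens: usage.outputTokens ?? 0,
            // v6 top-level field wins, then v6 inputTokenDetails, then Bedrock provider metadata.
            cacheReadTokens:
                usage.cachedInputTokens ??
                usage.inputTokenDetails?.cacheReadTokens ??
                providerMetadata.bedrock?.usage?.cacheReadInputTokens,
            cacheWriteTokens: usage.inputTokenDetails?.cacheWriteTokens,
            // Top-level reasoningTokens is preferred over outputTokenDetails.reasoningTokens.
            reasoningTokens: usage.reasoningTokens ?? usage.outputTokenDetails?.reasoningTokens,
        }
    }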
164 changes: 164 additions & 0 deletions src/api/providers/__tests__/gemini.spec.ts
@@ -472,4 +472,168 @@ describe("GeminiHandler", () => {
expect(mockCaptureException).toHaveBeenCalled()
})
})

describe("AI SDK v6 usage field paths", () => {
const mockMessages: RooMessage[] = [
{
role: "user",
content: "Hello",
},
]
const systemPrompt = "You are a helpful assistant"

function setupStream(usage: Record<string, unknown>) {
const mockFullStream = (async function* () {
yield { type: "text-delta", text: "reply" }
})()

mockStreamText.mockReturnValue({
fullStream: mockFullStream,
usage: Promise.resolve(usage),
providerMetadata: Promise.resolve({}),
})
}

describe("cache tokens", () => {
it("should read cache tokens from v6 top-level cachedInputTokens", async () => {
setupStream({ inputTokens: 100, outputTokens: 50, cachedInputTokens: 30 })

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.cacheReadTokens).toBe(30)
})

it("should read cache tokens from v6 inputTokenDetails.cacheReadTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
inputTokenDetails: { cacheReadTokens: 25 },
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.cacheReadTokens).toBe(25)
})

it("should prefer v6 top-level cachedInputTokens over legacy details", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
cachedInputTokens: 30,
details: { cachedInputTokens: 20 },
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.cacheReadTokens).toBe(30)
})

it("should fall back to legacy details.cachedInputTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
details: { cachedInputTokens: 20 },
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.cacheReadTokens).toBe(20)
})
})

describe("reasoning tokens", () => {
it("should read reasoning tokens from v6 top-level reasoningTokens", async () => {
setupStream({ inputTokens: 100, outputTokens: 50, reasoningTokens: 40 })

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.reasoningTokens).toBe(40)
})

it("should read reasoning tokens from v6 outputTokenDetails.reasoningTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
outputTokenDetails: { reasoningTokens: 35 },
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.reasoningTokens).toBe(35)
})

it("should prefer v6 top-level reasoningTokens over legacy details", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
reasoningTokens: 40,
details: { reasoningTokens: 15 },
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.reasoningTokens).toBe(40)
})

it("should fall back to legacy details.reasoningTokens", async () => {
setupStream({
inputTokens: 100,
outputTokens: 50,
details: { reasoningTokens: 15 },
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunk = chunks.find((c) => c.type === "usage")
expect(usageChunk).toBeDefined()
expect(usageChunk!.reasoningTokens).toBe(15)
})
})
})
})
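The Gemini tests encode a similar precedence, with a legacy usage.details object as the final fallback instead of provider metadata. A minimal sketch, again using a hypothetical helper name and field shapes taken from the mocks above:

    interface GeminiV6Usage {
        inputTokens?: number
        outputTokens?: number
        cachedInputTokens?: number
        reasoningTokens?: number
        inputTokenDetails?: { cacheReadTokens?: number }
        outputTokenDetails?: { reasoningTokens?: number }
        details?: { cachedInputTokens?: number; reasoningTokens?: number }
    }

    // Hypothetical helper mirroring the precedence asserted in the tests above.
    function resolveGeminiUsage(usage: GeminiV6Usage) {
        return {
            type: "usage" as const,
            inputTokens: usage.inputTokens ?? 0,
            outputTokens: usage.outputTokens ?? 0,
            // v6 top-level cachedInputTokens wins, then inputTokenDetails, then legacy details.
            cacheReadTokens:
                usage.cachedInputTokens ??
                usage.inputTokenDetails?.cacheReadTokens ??
                usage.details?.cachedInputTokens,
            // v6 top-level reasoningTokens wins, then outputTokenDetails, then legacy details.
            reasoningTokens:
                usage.reasoningTokens ??
                usage.outputTokenDetails?.reasoningTokens ??
                usage.details?.reasoningTokens,
        }
    }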
8 changes: 7 additions & 1 deletion src/api/providers/__tests__/native-ollama.spec.ts
@@ -84,7 +84,13 @@ describe("NativeOllamaHandler", () => {
expect(results).toHaveLength(3)
expect(results[0]).toEqual({ type: "text", text: "Hello" })
expect(results[1]).toEqual({ type: "text", text: " world" })
expect(results[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 2 })
expect(results[2]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 2,
totalInputTokens: 10,
totalOutputTokens: 2,
})
})

it("should not include providerOptions by default (no num_ctx)", async () => {
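For reference, the updated NativeOllama assertion now expects cumulative totals alongside the per-event counts. A sketch of the asserted chunk shape, with values taken from the test's mock, which emits a single usage event, so the totals equal the per-chunk counts:

    const expectedOllamaUsage = {
        type: "usage",
        inputTokens: 10,
        outputTokens: 2,
        // Running totals match the per-event counts because only one usage event is emitted.
        totalInputTokens: 10,
        totalOutputTokens: 2,
    }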