Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"bun": ">=1.2.11"
},
"dependencies": {
"@ai-sdk/anthropic": "1.0.8",
"@ai-sdk/google-vertex": "3.0.6",
"@ai-sdk/openai": "2.0.11",
"@codebuff/billing": "workspace:*",
Expand Down Expand Up @@ -56,4 +57,4 @@
"@types/express": "^4.17.13",
"@types/ws": "^8.5.5"
}
}
}
22 changes: 14 additions & 8 deletions backend/src/llm-apis/message-cost-tracker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ export const saveMessage = async (value: {
cacheReadInputTokens?: number
finishedAt: Date
latencyMs: number
usesUserApiKey?: boolean
usesUserApiKey?: boolean // Deprecated: use byokProvider instead
byokProvider?: 'anthropic' | 'gemini' | 'openai' | null
chargeUser?: boolean
costOverrideDollars?: number
agentId?: string
Expand All @@ -604,16 +605,21 @@ export const saveMessage = async (value: {
// Default to 1 cent per credit
const centsPerCredit = 1

// Determine if user API key was used (support both old and new parameters)
const usesUserKey = value.byokProvider !== null && value.byokProvider !== undefined
? !!value.byokProvider
: value.usesUserApiKey ?? false

const costInCents =
value.chargeUser ?? true // default to true
? Math.max(
0,
Math.round(
cost *
100 *
(value.usesUserApiKey ? PROFIT_MARGIN : 1 + PROFIT_MARGIN),
),
)
0,
Math.round(
cost *
100 *
(usesUserKey ? PROFIT_MARGIN : 1 + PROFIT_MARGIN),
),
)
: 0

const creditsUsed = Math.max(0, costInCents)
Expand Down
138 changes: 121 additions & 17 deletions backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { createAnthropic } from '@ai-sdk/anthropic'
import { createGoogleGenerativeAI, google } from '@ai-sdk/google'
import { createOpenAI, openai } from '@ai-sdk/openai'
import { env } from '@codebuff/internal'
import {
finetunedVertexModels,
geminiModels,
Expand Down Expand Up @@ -32,36 +34,123 @@ import type {
import type { LanguageModel } from 'ai'
import type { z } from 'zod/v4'

/**
 * Per-provider API keys supplied by the user for BYOK (Bring Your Own Key).
 * A missing key means "fall back to the system key" for that provider,
 * subject to the active ByokMode.
 */
export interface UserApiKeys {
  /** Anthropic API key (applies to models prefixed with 'anthropic/'). */
  anthropic?: string
  /** Google Gemini API key. */
  gemini?: string
  /** OpenAI API key. */
  openai?: string
}

/**
 * Controls how user-supplied (BYOK) keys are used when building a model:
 * - 'disabled': always use system keys; user keys are ignored
 * - 'prefer':   use the user's key when present, else fall back to the system key
 * - 'require':  throw if the user has not supplied a key for the provider
 */
export type ByokMode = 'disabled' | 'prefer' | 'require'

export type StreamChunk =
| {
type: 'text'
text: string
}
type: 'text'
text: string
}
| {
type: 'reasoning'
text: string
}
type: 'reasoning'
text: string
}
| { type: 'error'; message: string }

// TODO: We'll want to add all our models here!
const modelToAiSDKModel = (model: Model): LanguageModel => {
/**
* Helper function to determine if a model is an Anthropic model
*/
function isAnthropicModel(model: Model): boolean {
return model.startsWith('anthropic/')
}

/**
* Helper function to determine which provider key was used for BYOK
*/
function determineByokProvider(
model: Model,
userApiKeys?: UserApiKeys,
): 'anthropic' | 'gemini' | 'openai' | null {
if (isAnthropicModel(model) && userApiKeys?.anthropic) return 'anthropic'
if (Object.values(geminiModels).includes(model as GeminiModel) && userApiKeys?.gemini) return 'gemini'
if (Object.values(openaiModels).includes(model as OpenAIModel) && userApiKeys?.openai) return 'openai'
return null
}

/**
* Convert a model string to an AI SDK LanguageModel instance.
* Supports BYOK (Bring Your Own Key) for Anthropic, Gemini, and OpenAI.
*/
const modelToAiSDKModel = (
model: Model,
userApiKeys?: UserApiKeys,
byokMode: ByokMode = 'prefer',
): LanguageModel => {
// Finetuned Vertex models
if (
Object.values(finetunedVertexModels as Record<string, string>).includes(
model,
)
) {
return vertexFinetuned(model)
}

// Gemini models - direct to Google
if (Object.values(geminiModels).includes(model as GeminiModel)) {
return google.languageModel(model)
const apiKey =
byokMode === 'disabled'
? env.GEMINI_API_KEY
: userApiKeys?.gemini ?? env.GEMINI_API_KEY

if (byokMode === 'require' && !userApiKeys?.gemini) {
throw new Error('Gemini API key required but not provided (byokMode: require)')
}

return google.languageModel(model, { apiKey })
}

// OpenAI models - direct to OpenAI
if (model === openaiModels.o3pro || model === openaiModels.o3) {
return openai.responses(model)
const apiKey =
byokMode === 'disabled'
? env.OPENAI_API_KEY
: userApiKeys?.openai ?? env.OPENAI_API_KEY

if (byokMode === 'require' && !userApiKeys?.openai) {
throw new Error('OpenAI API key required but not provided (byokMode: require)')
}

return openai.responses(model, { apiKey })
}

if (Object.values(openaiModels).includes(model as OpenAIModel)) {
return openai.languageModel(model)
const apiKey =
byokMode === 'disabled'
? env.OPENAI_API_KEY
: userApiKeys?.openai ?? env.OPENAI_API_KEY

if (byokMode === 'require' && !userApiKeys?.openai) {
throw new Error('OpenAI API key required but not provided (byokMode: require)')
}

return openai.languageModel(model, { apiKey })
}
// All other models go through OpenRouter

// Anthropic models - direct to Anthropic (if user key provided) or OpenRouter
if (isAnthropicModel(model)) {
// If user has Anthropic key and byokMode allows it, use direct Anthropic API
if (byokMode !== 'disabled' && userApiKeys?.anthropic) {
const anthropic = createAnthropic({ apiKey: userApiKeys.anthropic })
return anthropic.languageModel(model)
}

// If byokMode is 'require', fail if no user key
if (byokMode === 'require') {
throw new Error('Anthropic API key required but not provided (byokMode: require)')
}

// Otherwise, use OpenRouter with system key
return openRouterLanguageModel(model)
}

// All other models go through OpenRouter with system key
return openRouterLanguageModel(model)
}

Expand All @@ -82,6 +171,8 @@ export const promptAiSdkStream = async function* (
maxRetries?: number
onCostCalculated?: (credits: number) => Promise<void>
includeCacheControl?: boolean
userApiKeys?: UserApiKeys
byokMode?: ByokMode
} & Omit<Parameters<typeof streamText>[0], 'model' | 'messages'>,
): AsyncGenerator<StreamChunk, string | null> {
if (
Expand All @@ -103,7 +194,8 @@ export const promptAiSdkStream = async function* (
}
const startTime = Date.now()

let aiSDKModel = modelToAiSDKModel(options.model)
const byokMode = options.byokMode ?? 'prefer'
let aiSDKModel = modelToAiSDKModel(options.model, options.userApiKeys, byokMode)

const response = streamText({
...options,
Expand Down Expand Up @@ -156,8 +248,8 @@ export const promptAiSdkStream = async function* (
if (
(
options.providerOptions?.openrouter as
| OpenRouterProviderOptions
| undefined
| OpenRouterProviderOptions
| undefined
)?.reasoning?.exclude
) {
continue
Expand Down Expand Up @@ -230,6 +322,7 @@ export const promptAiSdkStream = async function* (
}

const messageId = (await response.response).id
const byokProvider = determineByokProvider(options.model, options.userApiKeys)
const creditsUsedPromise = saveMessage({
messageId,
userId: options.userId,
Expand All @@ -246,6 +339,7 @@ export const promptAiSdkStream = async function* (
finishedAt: new Date(),
latencyMs: Date.now() - startTime,
chargeUser: options.chargeUser ?? true,
byokProvider,
costOverrideDollars,
agentId: options.agentId,
})
Expand Down Expand Up @@ -273,6 +367,8 @@ export const promptAiSdk = async function (
onCostCalculated?: (credits: number) => Promise<void>
includeCacheControl?: boolean
maxRetries?: number
userApiKeys?: UserApiKeys
byokMode?: ByokMode
} & Omit<Parameters<typeof generateText>[0], 'model' | 'messages'>,
): Promise<string> {
if (
Expand All @@ -294,7 +390,8 @@ export const promptAiSdk = async function (
}

const startTime = Date.now()
let aiSDKModel = modelToAiSDKModel(options.model)
const byokMode = options.byokMode ?? 'prefer'
let aiSDKModel = modelToAiSDKModel(options.model, options.userApiKeys, byokMode)

const response = await generateText({
...options,
Expand All @@ -305,6 +402,7 @@ export const promptAiSdk = async function (
const inputTokens = response.usage.inputTokens || 0
const outputTokens = response.usage.inputTokens || 0

const byokProvider = determineByokProvider(options.model, options.userApiKeys)
const creditsUsedPromise = saveMessage({
messageId: generateCompactId(),
userId: options.userId,
Expand All @@ -320,6 +418,7 @@ export const promptAiSdk = async function (
latencyMs: Date.now() - startTime,
chargeUser: options.chargeUser ?? true,
agentId: options.agentId,
byokProvider,
})

// Call the cost callback if provided
Expand Down Expand Up @@ -348,6 +447,8 @@ export const promptAiSdkStructured = async function <T>(options: {
onCostCalculated?: (credits: number) => Promise<void>
includeCacheControl?: boolean
maxRetries?: number
userApiKeys?: UserApiKeys
byokMode?: ByokMode
}): Promise<T> {
if (
!checkLiveUserInput(
Expand All @@ -367,7 +468,8 @@ export const promptAiSdkStructured = async function <T>(options: {
return {} as T
}
const startTime = Date.now()
let aiSDKModel = modelToAiSDKModel(options.model)
const byokMode = options.byokMode ?? 'prefer'
let aiSDKModel = modelToAiSDKModel(options.model, options.userApiKeys, byokMode)

const responsePromise = generateObject<z.ZodType<T>, 'object'>({
...options,
Expand All @@ -383,6 +485,7 @@ export const promptAiSdkStructured = async function <T>(options: {
const inputTokens = response.usage.inputTokens || 0
const outputTokens = response.usage.inputTokens || 0

const byokProvider = determineByokProvider(options.model, options.userApiKeys)
const creditsUsedPromise = saveMessage({
messageId: generateCompactId(),
userId: options.userId,
Expand All @@ -398,6 +501,7 @@ export const promptAiSdkStructured = async function <T>(options: {
latencyMs: Date.now() - startTime,
chargeUser: options.chargeUser ?? true,
agentId: options.agentId,
byokProvider,
})

// Call the cost callback if provided
Expand Down
41 changes: 41 additions & 0 deletions backend/src/main-prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { getAgentTemplate } from './templates/agent-registry'
import { logger } from './util/logger'
import { expireMessages } from './util/messages'
import { requestToolCall } from './websockets/websocket-action'
import { retrieveAndDecryptApiKey } from '@codebuff/common/api-keys/crypto'

import type { AgentTemplate } from './templates/types'
import type { ClientAction } from '@codebuff/common/actions'
Expand All @@ -19,6 +20,7 @@ import type {
AgentOutput,
} from '@codebuff/common/types/session-state'
import type { WebSocket } from 'ws'
import type { UserApiKeys, ByokMode } from './llm-apis/vercel-ai-sdk/ai-sdk'

export interface MainPromptOptions {
userId: string | undefined
Expand All @@ -27,6 +29,38 @@ export interface MainPromptOptions {
localAgentTemplates: Record<string, AgentTemplate>
}

/**
 * Retrieves user API keys from the database for BYOK (Bring Your Own Key).
 * Merges SDK-provided keys with database keys; SDK keys take precedence.
 *
 * The three provider lookups run in parallel and are settled independently,
 * so a failure retrieving one provider's key no longer discards the keys of
 * the providers that succeeded — only the failed provider falls back to the
 * SDK-supplied key (or undefined).
 *
 * @param userId  The user whose stored keys to look up; when undefined, the
 *                SDK keys are returned unchanged (no database access).
 * @param sdkKeys Keys passed in by the SDK caller; these always win.
 */
async function getUserApiKeys(
  userId: string | undefined,
  sdkKeys?: UserApiKeys,
): Promise<UserApiKeys | undefined> {
  // Without a user we cannot query the database; use SDK keys as-is.
  if (!userId) {
    return sdkKeys
  }

  // Settle each lookup independently; log and drop only the failed ones.
  const [anthropicKey, geminiKey, openaiKey] = (
    await Promise.allSettled([
      retrieveAndDecryptApiKey(userId, 'anthropic'),
      retrieveAndDecryptApiKey(userId, 'gemini'),
      retrieveAndDecryptApiKey(userId, 'openai'),
    ])
  ).map((result) => {
    if (result.status === 'fulfilled') {
      return result.value ?? undefined
    }
    logger.error(
      { error: result.reason, userId },
      'Failed to retrieve user API key',
    )
    return undefined
  })

  // Merge with SDK keys (SDK keys take precedence)
  return {
    anthropic: sdkKeys?.anthropic ?? anthropicKey ?? undefined,
    gemini: sdkKeys?.gemini ?? geminiKey ?? undefined,
    openai: sdkKeys?.openai ?? openaiKey ?? undefined,
  }
}

export const mainPrompt = async (
ws: WebSocket,
action: ClientAction<'prompt'>,
Expand All @@ -47,9 +81,14 @@ export const mainPrompt = async (
promptId,
agentId,
promptParams,
userApiKeys: sdkUserApiKeys,
byokMode,
} = action
const { fileContext, mainAgentState } = sessionState

// Retrieve and merge user API keys (SDK keys take precedence over DB keys)
const userApiKeys = await getUserApiKeys(userId, sdkUserApiKeys)

if (prompt) {
// Check if this is a direct terminal command
const startTime = Date.now()
Expand Down Expand Up @@ -203,6 +242,8 @@ export const mainPrompt = async (
clientSessionId,
onResponseChunk,
localAgentTemplates,
userApiKeys,
byokMode,
})

logger.debug({ agentState, output }, 'Main prompt finished')
Expand Down
Loading