diff --git a/README.md b/README.md index 8f4d201aa..82646d1a4 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ Gateway seamlessly integrates with popular agent frameworks. [Read the documenta |------------------------------|--------|-------------|---------|------|---------------|-------------------| | [Autogen](https://docs.portkey.ai/docs/welcome/agents/autogen) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [CrewAI](https://docs.portkey.ai/docs/welcome/agents/crewai) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [LangChain](https://docs.portkey.ai/docs/welcome/agents/langchain-agents) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [Phidata](https://docs.portkey.ai/docs/welcome/agents/phidata) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [Llama Index](https://docs.portkey.ai/docs/welcome/agents/llama-agents) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [Control Flow](https://docs.portkey.ai/docs/welcome/agents/control-flow) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/jest.config.cjs b/jest.config.cjs deleted file mode 100644 index b413e106d..000000000 --- a/jest.config.cjs +++ /dev/null @@ -1,5 +0,0 @@ -/** @type {import('ts-jest').JestConfigWithTsJest} */ -module.exports = { - preset: 'ts-jest', - testEnvironment: 'node', -}; \ No newline at end of file diff --git a/jest.config.js b/jest.config.js new file mode 100644 index 000000000..f22307146 --- /dev/null +++ b/jest.config.js @@ -0,0 +1,8 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} **/ +export default { + testEnvironment: 'node', + transform: { + '^.+.tsx?$': ['ts-jest', {}], + }, + testTimeout: 30000, // Set default timeout to 30 seconds +}; diff --git a/package-lock.json b/package-lock.json index dd4e0207d..65ed1f2f4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@portkey-ai/gateway", - "version": "1.6.1", + "version": "1.7.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@portkey-ai/gateway", - "version": "1.6.1", + "version": "1.7.0", "license": "MIT", "dependencies": { "@aws-crypto/sha256-js": "^5.2.0", @@ -33,7 +33,7 @@ "jest": "^29.7.0", "prettier": "3.2.5", "rollup": "^4.9.1", - "ts-jest": "^29.1.2", + "ts-jest": "^29.2.4", "tsx": "^4.7.0", "typescript-eslint": "^8.1.0", "wrangler": "^3.48.0" @@ -2621,6 +2621,12 @@ "printable-characters": "^1.0.42" } }, + "node_modules/async": { + "version": "3.2.6", + "resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz", + "integrity": "sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==", + "dev": true + }, "node_modules/async-retry": { "version": "1.3.3", "resolved": "https://registry.npmjs.org/async-retry/-/async-retry-1.3.3.tgz", @@ -3152,6 +3158,21 @@ "node": ">=8" } }, + "node_modules/ejs": { + "version": "3.1.10", + "resolved": "https://registry.npmjs.org/ejs/-/ejs-3.1.10.tgz", + "integrity": "sha512-UeJmFfOrAQS8OJWPZ4qtgHyWExa088/MtK5UEyoJGFH67cDEXkZSviOiKRCZ4Xij0zxI3JECgYs3oKx+AizQBA==", + "dev": true, + "dependencies": { + "jake": "^10.8.5" + }, + "bin": { + "ejs": "bin/cli.js" + }, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/electron-to-chromium": { "version": "1.4.763", "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.763.tgz", @@ -3630,6 +3651,36 @@ "node": ">=16.0.0" } }, + "node_modules/filelist": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.4.tgz", + "integrity": "sha512-w1cEuf3S+DrLCQL7ET6kz+gmlJdbq9J7yXCSjK/OZCPA+qEN1WyF4ZAf0YYJa4/shHJra2t/d/r8SV4Ji+x+8Q==", + "dev": true, + "dependencies": { + "minimatch": "^5.0.1" + } + }, + 
"node_modules/filelist/node_modules/brace-expansion": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "dev": true, + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/filelist/node_modules/minimatch": { + "version": "5.1.6", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", + "integrity": "sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g==", + "dev": true, + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/fill-range": { "version": "7.0.1", "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", @@ -4168,6 +4219,24 @@ "node": ">=8" } }, + "node_modules/jake": { + "version": "10.9.2", + "resolved": "https://registry.npmjs.org/jake/-/jake-10.9.2.tgz", + "integrity": "sha512-2P4SQ0HrLQ+fw6llpLnOaGAvN2Zu6778SJMrCUwns4fOoG9ayrTiZk3VV8sCPkVZF8ab0zksVpS8FDY5pRCNBA==", + "dev": true, + "dependencies": { + "async": "^3.2.3", + "chalk": "^4.0.2", + "filelist": "^1.0.4", + "minimatch": "^3.1.2" + }, + "bin": { + "jake": "bin/cli.js" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/jest": { "version": "29.7.0", "resolved": "https://registry.npmjs.org/jest/-/jest-29.7.0.tgz", @@ -5978,12 +6047,13 @@ } }, "node_modules/ts-jest": { - "version": "29.1.2", - "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-29.1.2.tgz", - "integrity": "sha512-br6GJoH/WUX4pu7FbZXuWGKGNDuU7b8Uj77g/Sp7puZV6EXzuByl6JrECvm0MzVzSTkSHWTihsXt+5XYER5b+g==", + "version": "29.2.4", + "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-29.2.4.tgz", + "integrity": "sha512-3d6tgDyhCI29HlpwIq87sNuI+3Q6GLTTCeYRHCs7vDz+/3GCMwEtV9jezLyl4ZtnBgx00I7hm8PCP8cTksMGrw==", "dev": true, "dependencies": { "bs-logger": "0.x", + "ejs": "^3.1.10", "fast-json-stable-stringify": "2.x", "jest-util": "^29.0.0", "json5": "^2.2.3", @@ -5996,10 +6066,11 @@ "ts-jest": "cli.js" }, "engines": { - "node": "^16.10.0 || ^18.0.0 || >=20.0.0" + "node": "^14.15.0 || ^16.10.0 || ^18.0.0 || >=20.0.0" }, "peerDependencies": { "@babel/core": ">=7.0.0-beta.0 <8", + "@jest/transform": "^29.0.0", "@jest/types": "^29.0.0", "babel-jest": "^29.0.0", "jest": "^29.0.0", @@ -6009,6 +6080,9 @@ "@babel/core": { "optional": true }, + "@jest/transform": { + "optional": true + }, "@jest/types": { "optional": true }, diff --git a/package.json b/package.json index 6d1626d13..5c038e3fc 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@portkey-ai/gateway", - "version": "1.6.1", + "version": "1.7.0", "description": "A fast AI gateway by Portkey", "repository": { "type": "git", @@ -60,7 +60,7 @@ "jest": "^29.7.0", "prettier": "3.2.5", "rollup": "^4.9.1", - "ts-jest": "^29.1.2", + "ts-jest": "^29.2.4", "tsx": "^4.7.0", "typescript-eslint": "^8.1.0", "wrangler": "^3.48.0" diff --git a/plugins/aporia/validateProject.ts b/plugins/aporia/validateProject.ts index 9ca57f068..f059d9cf1 100644 --- a/plugins/aporia/validateProject.ts +++ b/plugins/aporia/validateProject.ts @@ -4,7 +4,7 @@ import { PluginHandler, PluginParameters, } from '../types'; -import { getText, post } from '../utils'; +import { post } from '../utils'; export const APORIA_BASE_URL = 'https://gr-prd.aporia.com'; diff --git a/src/errors/RouterError.ts b/src/errors/RouterError.ts new file mode 100644 index 000000000..9d5c9e5bd --- /dev/null 
+++ b/src/errors/RouterError.ts @@ -0,0 +1,9 @@ +export class RouterError extends Error { + constructor( + message: string, + public cause?: Error + ) { + super(message); + this.name = 'RouterError'; + } +} diff --git a/src/globals.ts b/src/globals.ts index 7e4381756..4d3278591 100644 --- a/src/globals.ts +++ b/src/globals.ts @@ -1,3 +1,5 @@ +import { endpointStrings } from './providers/types'; + export const POWERED_BY: string = 'portkey'; export const HEADER_KEYS: Record = { @@ -6,6 +8,7 @@ export const HEADER_KEYS: Record = { PROVIDER: `x-${POWERED_BY}-provider`, TRACE_ID: `x-${POWERED_BY}-trace-id`, CACHE: `x-${POWERED_BY}-cache`, + METADATA: `x-${POWERED_BY}-metadata`, FORWARD_HEADERS: `x-${POWERED_BY}-forward-headers`, CUSTOM_HOST: `x-${POWERED_BY}-custom-host`, REQUEST_TIMEOUT: `x-${POWERED_BY}-request-timeout`, @@ -104,3 +107,8 @@ export const CONTENT_TYPES = { HTML: 'text/html', GENERIC_IMAGE_PATTERN: 'image/', }; + +export const MULTIPART_FORM_DATA_ENDPOINTS: endpointStrings[] = [ + 'createTranscription', + 'createTranslation', +]; diff --git a/src/handlers/chatCompletionsHandler.ts b/src/handlers/chatCompletionsHandler.ts index 077d04a04..5c461b840 100644 --- a/src/handlers/chatCompletionsHandler.ts +++ b/src/handlers/chatCompletionsHandler.ts @@ -1,3 +1,4 @@ +import { RouterError } from '../errors/RouterError'; import { constructConfigFromRequestHeaders, tryTargetsRecursively, @@ -30,13 +31,21 @@ export async function chatCompletionsHandler(c: Context): Promise { return tryTargetsResponse; } catch (err: any) { console.log('chatCompletion error', err.message); + let statusCode = 500; + let errorMessage = 'Something went wrong'; + + if (err instanceof RouterError) { + statusCode = 400; + errorMessage = err.message; + } + return new Response( JSON.stringify({ status: 'failure', - message: 'Something went wrong', + message: errorMessage, }), { - status: 500, + status: statusCode, headers: { 'content-type': 'application/json', }, diff --git a/src/handlers/completionsHandler.ts b/src/handlers/completionsHandler.ts index 6b058478d..a1a896484 100644 --- a/src/handlers/completionsHandler.ts +++ b/src/handlers/completionsHandler.ts @@ -1,3 +1,4 @@ +import { RouterError } from '../errors/RouterError'; import { constructConfigFromRequestHeaders, tryTargetsRecursively, @@ -31,6 +32,14 @@ export async function completionsHandler(c: Context): Promise { return tryTargetsResponse; } catch (err: any) { console.log('completion error', err.message); + let statusCode = 500; + let errorMessage = 'Something went wrong'; + + if (err instanceof RouterError) { + statusCode = 400; + errorMessage = err.message; + } + return new Response( JSON.stringify({ status: 'failure', diff --git a/src/handlers/createSpeechHandler.ts b/src/handlers/createSpeechHandler.ts new file mode 100644 index 000000000..efb142da2 --- /dev/null +++ b/src/handlers/createSpeechHandler.ts @@ -0,0 +1,46 @@ +import { + constructConfigFromRequestHeaders, + tryTargetsRecursively, +} from './handlerUtils'; +import { Context } from 'hono'; + +/** + * Handles the '/audio/speech' API request by selecting the appropriate provider(s) and making the request to them. + * + * @param {Context} c - The Cloudflare Worker context. + * @returns {Promise} - The response from the provider. + * @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails. 
+ * @throws Will throw a 500 error if the handler itself fails for any reason
+ */
+export async function createSpeechHandler(c: Context): Promise<Response> {
+  try {
+    let request = await c.req.json();
+    let requestHeaders = Object.fromEntries(c.req.raw.headers);
+    const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
+    const tryTargetsResponse = await tryTargetsRecursively(
+      c,
+      camelCaseConfig ?? {},
+      request,
+      requestHeaders,
+      'createSpeech',
+      'POST',
+      'config'
+    );
+
+    return tryTargetsResponse;
+  } catch (err: any) {
+    console.log('createSpeech error', err.message);
+    return new Response(
+      JSON.stringify({
+        status: 'failure',
+        message: 'Something went wrong',
+      }),
+      {
+        status: 500,
+        headers: {
+          'content-type': 'application/json',
+        },
+      }
+    );
+  }
+}
diff --git a/src/handlers/createTranscriptionHandler.ts b/src/handlers/createTranscriptionHandler.ts
new file mode 100644
index 000000000..7060372c4
--- /dev/null
+++ b/src/handlers/createTranscriptionHandler.ts
@@ -0,0 +1,48 @@
+import {
+  constructConfigFromRequestHeaders,
+  tryTargetsRecursively,
+} from './handlerUtils';
+import { Context } from 'hono';
+
+/**
+ * Handles the '/audio/transcriptions' API request by selecting the appropriate provider(s) and making the request to them.
+ *
+ * @param {Context} c - The Cloudflare Worker context.
+ * @returns {Promise<Response>} - The response from the provider.
+ * @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
+ * @throws Will throw a 500 error if the handler itself fails for any reason
+ */
+export async function createTranscriptionHandler(
+  c: Context
+): Promise<Response> {
+  try {
+    let request = await c.req.raw.formData();
+    let requestHeaders = Object.fromEntries(c.req.raw.headers);
+    const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
+    const tryTargetsResponse = await tryTargetsRecursively(
+      c,
+      camelCaseConfig ?? {},
+      request,
+      requestHeaders,
+      'createTranscription',
+      'POST',
+      'config'
+    );
+
+    return tryTargetsResponse;
+  } catch (err: any) {
+    console.log('createTranscription error', err.message);
+    return new Response(
+      JSON.stringify({
+        status: 'failure',
+        message: 'Something went wrong',
+      }),
+      {
+        status: 500,
+        headers: {
+          'content-type': 'application/json',
+        },
+      }
+    );
+  }
+}
diff --git a/src/handlers/createTranslationHandler.ts b/src/handlers/createTranslationHandler.ts
new file mode 100644
index 000000000..a0a7aee46
--- /dev/null
+++ b/src/handlers/createTranslationHandler.ts
@@ -0,0 +1,46 @@
+import {
+  constructConfigFromRequestHeaders,
+  tryTargetsRecursively,
+} from './handlerUtils';
+import { Context } from 'hono';
+
+/**
+ * Handles the '/audio/translations' API request by selecting the appropriate provider(s) and making the request to them.
+ *
+ * @param {Context} c - The Cloudflare Worker context.
+ * @returns {Promise<Response>} - The response from the provider.
+ * @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
+ * @throws Will throw a 500 error if the handler itself fails for any reason
+ */
+export async function createTranslationHandler(c: Context): Promise<Response> {
+  try {
+    let request = await c.req.raw.formData();
+    let requestHeaders = Object.fromEntries(c.req.raw.headers);
+    const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
+    const tryTargetsResponse = await tryTargetsRecursively(
+      c,
+      camelCaseConfig ?? {},
+      request,
+      requestHeaders,
+      'createTranslation',
+      'POST',
+      'config'
+    );
+
+    return tryTargetsResponse;
+  } catch (err: any) {
+    console.log('createTranslation error', err.message);
+    return new Response(
+      JSON.stringify({
+        status: 'failure',
+        message: 'Something went wrong',
+      }),
+      {
+        status: 500,
+        headers: {
+          'content-type': 'application/json',
+        },
+      }
+    );
+  }
+}
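For reference (not part of the diff): the three handlers above are exercised like any other gateway route, except that the transcription and translation handlers read `formData()` instead of JSON and pass the form through unchanged. A minimal client sketch follows; the port, provider header, model, and file name reflect an assumed local wrangler setup, not anything this PR pins down:

```ts
// Hedged sketch: POST a multipart transcription request to a locally
// running gateway. URL/port, headers, and 'whisper-1' are assumptions.
import { readFileSync } from 'node:fs';

const form = new FormData();
form.append('file', new Blob([readFileSync('clip.mp3')]), 'clip.mp3');
form.append('model', 'whisper-1');

const res = await fetch('http://localhost:8787/v1/audio/transcriptions', {
  method: 'POST',
  headers: {
    'x-portkey-provider': 'openai',
    Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
  },
  body: form, // no content-type header: fetch adds the multipart boundary
});
console.log(await res.json());
```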
diff --git a/src/handlers/embeddingsHandler.ts b/src/handlers/embeddingsHandler.ts
index 2535c07ac..a6caddd56 100644
--- a/src/handlers/embeddingsHandler.ts
+++ b/src/handlers/embeddingsHandler.ts
@@ -1,3 +1,4 @@
+import { RouterError } from '../errors/RouterError';
 import {
   constructConfigFromRequestHeaders,
   tryTargetsRecursively,
@@ -30,7 +31,15 @@ export async function embeddingsHandler(c: Context): Promise<Response> {
     return tryTargetsResponse;
   } catch (err: any) {
-    console.log('completion error', err.message);
+    console.log('embeddings error', err.message);
+    let statusCode = 500;
+    let errorMessage = 'Something went wrong';
+
+    if (err instanceof RouterError) {
+      statusCode = 400;
+      errorMessage = err.message;
+    }
+
     return new Response(
       JSON.stringify({
         status: 'failure',
diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts
index 4b08ce541..a057cadc5 100644
--- a/src/handlers/handlerUtils.ts
+++ b/src/handlers/handlerUtils.ts
@@ -9,6 +9,8 @@ import {
   RETRY_STATUS_CODES,
   GOOGLE_VERTEX_AI,
   OPEN_AI,
+  MULTIPART_FORM_DATA_ENDPOINTS,
+  CONTENT_TYPES,
 } from '../globals';
 import Providers from '../providers';
 import { ProviderAPIConfig, endpointStrings } from '../providers/types';
@@ -18,14 +20,16 @@ import {
   Options,
   Params,
   ShortConfig,
+  StrategyModes,
   Targets,
 } from '../types/requestBody';
 import { convertKeysToCamelCase } from '../utils';
 import { retryRequest } from './retryHandler';
-import { env } from 'hono/adapter';
+import { env, getRuntimeKey } from 'hono/adapter';
 import { afterRequestHookHandler, responseHandler } from './responseHandlers';
-import { getRuntimeKey } from 'hono/adapter';
 import { HookSpan, HooksManager } from '../middlewares/hooks';
+import { ConditionalRouter } from '../services/conditionalRouter';
+import { RouterError } from '../errors/RouterError';

 /**
  * Constructs the request options for the API call.
@@ -68,9 +72,13 @@ export function constructRequest(
     method,
     headers,
   };
+  const contentType = headers['content-type'];
+  const isGetMethod = method === 'GET';
+  const isMultipartFormData = contentType === CONTENT_TYPES.MULTIPART_FORM_DATA;
+  const shouldDeleteContentTypeHeader =
+    (isGetMethod || isMultipartFormData) && fetchOptions.headers;

-  // If the method is GET, delete the content-type header
-  if (method === 'GET' && fetchOptions.headers) {
+  if (shouldDeleteContentTypeHeader) {
     let headers = fetchOptions.headers as Record<string, string>;
     delete headers['content-type'];
   }
@@ -440,7 +448,7 @@ export async function tryPostProxy(
 export async function tryPost(
   c: Context,
   providerOption: Options,
-  inputParams: Params,
+  inputParams: Params | FormData,
   requestHeaders: Record<string, string>,
   fn: endpointStrings,
   currentIndex: number | string
@@ -475,6 +483,7 @@ export async function tryPost(
   const transformedRequestBody = transformToProviderRequest(
     provider,
     params,
+    inputParams,
     fn
   );
@@ -514,7 +523,9 @@ export async function tryPost(
     requestHeaders
   );

-  fetchOptions.body = JSON.stringify(transformedRequestBody);
+  fetchOptions.body = MULTIPART_FORM_DATA_ENDPOINTS.includes(fn)
+    ?
(transformedRequestBody as FormData) + : JSON.stringify(transformedRequestBody); providerOption.retry = { attempts: providerOption.retry?.attempts ?? 0, @@ -702,7 +713,7 @@ export async function tryProvidersInSequence( export async function tryTargetsRecursively( c: Context, targetGroup: Targets, - request: Params, + request: Params | FormData, requestHeaders: Record, fn: endpointStrings, method: string, @@ -795,7 +806,7 @@ export async function tryTargetsRecursively( let response; switch (strategyMode) { - case 'fallback': + case StrategyModes.FALLBACK: for (let [index, target] of currentTarget.targets.entries()) { response = await tryTargetsRecursively( c, @@ -816,7 +827,7 @@ export async function tryTargetsRecursively( } break; - case 'loadbalance': + case StrategyModes.LOADBALANCE: currentTarget.targets.forEach((t: Options) => { if (t.weight === undefined) { t.weight = 1; @@ -847,7 +858,35 @@ export async function tryTargetsRecursively( } break; - case 'single': + case StrategyModes.CONDITIONAL: + let metadata: Record; + try { + metadata = JSON.parse(requestHeaders[HEADER_KEYS.METADATA]); + } catch (err) { + metadata = {}; + } + let conditionalRouter: ConditionalRouter; + let finalTarget: Targets; + try { + conditionalRouter = new ConditionalRouter(currentTarget, { metadata }); + finalTarget = conditionalRouter.resolveTarget(); + } catch (conditionalRouter: any) { + throw new RouterError(conditionalRouter.message); + } + + response = await tryTargetsRecursively( + c, + finalTarget, + request, + requestHeaders, + fn, + method, + `${currentJsonPath}.targets[${finalTarget.index}]`, + currentInheritedConfig + ); + break; + + case StrategyModes.SINGLE: response = await tryTargetsRecursively( c, currentTarget.targets[0], @@ -1017,6 +1056,7 @@ export function constructConfigFromRequestHeaders( 'params', 'checks', 'vertex_service_account_json', + 'conditions', ]) as any; } diff --git a/src/handlers/imageGenerationsHandler.ts b/src/handlers/imageGenerationsHandler.ts index eab18482c..7d47daf08 100644 --- a/src/handlers/imageGenerationsHandler.ts +++ b/src/handlers/imageGenerationsHandler.ts @@ -1,3 +1,4 @@ +import { RouterError } from '../errors/RouterError'; import { constructConfigFromRequestHeaders, tryTargetsRecursively, @@ -31,6 +32,13 @@ export async function imageGenerationsHandler(c: Context): Promise { return tryTargetsResponse; } catch (err: any) { console.log('imageGenerate error', err.message); + let statusCode = 500; + let errorMessage = 'Something went wrong'; + + if (err instanceof RouterError) { + statusCode = 400; + errorMessage = err.message; + } return new Response( JSON.stringify({ status: 'failure', diff --git a/src/handlers/proxyGetHandler.ts b/src/handlers/proxyGetHandler.ts index 4fbcd51bf..736f71af0 100644 --- a/src/handlers/proxyGetHandler.ts +++ b/src/handlers/proxyGetHandler.ts @@ -1,4 +1,4 @@ -import { Context, HonoRequest } from 'hono'; +import { Context } from 'hono'; import { retryRequest } from './retryHandler'; import Providers from '../providers'; import { @@ -7,7 +7,6 @@ import { HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, - RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, } from '../globals'; import { updateResponseHeaders } from './handlerUtils'; diff --git a/src/handlers/streamHandler.ts b/src/handlers/streamHandler.ts index 1b36b4211..a8fc462d9 100644 --- a/src/handlers/streamHandler.ts +++ b/src/handlers/streamHandler.ts @@ -6,7 +6,9 @@ import { GOOGLE, REQUEST_TIMEOUT_STATUS_CODE, PRECONDITION_CHECK_FAILED_STATUS_CODE, + GOOGLE_VERTEX_AI, } from '../globals'; 
+import { VertexLlamaChatCompleteStreamChunkTransform } from '../providers/google-vertex-ai/chatComplete'; import { OpenAIChatCompleteResponse } from '../providers/openai/chatComplete'; import { OpenAICompleteResponse } from '../providers/openai/complete'; import { getStreamModeSplitPattern, type SplitPatternType } from '../utils'; @@ -306,15 +308,15 @@ export function handleStreamingMode( } // Convert GEMINI/COHERE json stream to text/event-stream for non-proxy calls - if ( - [ - // - GOOGLE, - COHERE, - BEDROCK, - ].includes(proxyProvider) && - responseTransformer - ) { + const isGoogleCohereOrBedrock = [GOOGLE, COHERE, BEDROCK].includes( + proxyProvider + ); + const isVertexLlama = + proxyProvider === GOOGLE_VERTEX_AI && + responseTransformer?.name === + VertexLlamaChatCompleteStreamChunkTransform.name; + const isJsonStream = isGoogleCohereOrBedrock || isVertexLlama; + if (isJsonStream && responseTransformer) { return new Response(readable, { ...response, headers: new Headers({ diff --git a/src/index.ts b/src/index.ts index 606adae1f..2d1e0ff03 100644 --- a/src/index.ts +++ b/src/index.ts @@ -23,7 +23,10 @@ import { compress } from 'hono/compress'; import { getRuntimeKey } from 'hono/adapter'; import { imageGenerationsHandler } from './handlers/imageGenerationsHandler'; import { memoryCache } from './middlewares/cache'; +import { createSpeechHandler } from './handlers/createSpeechHandler'; import conf from '../conf.json'; +import { createTranscriptionHandler } from './handlers/createTranscriptionHandler'; +import { createTranslationHandler } from './handlers/createTranslationHandler'; // Create a new Hono server instance const app = new Hono(); @@ -122,6 +125,28 @@ app.post('/v1/embeddings', requestValidator, embeddingsHandler); */ app.post('/v1/images/generations', requestValidator, imageGenerationsHandler); +/** + * POST route for '/v1/audio/speech'. + * Handles requests by passing them to the createSpeechHandler. + */ +app.post('/v1/audio/speech', requestValidator, createSpeechHandler); + +/** + * POST route for '/v1/audio/transcriptions'. + * Handles requests by passing them to the createTranscriptionHandler. + */ +app.post( + '/v1/audio/transcriptions', + requestValidator, + createTranscriptionHandler +); + +/** + * POST route for '/v1/audio/translations'. + * Handles requests by passing them to the createTranslationHandler. + */ +app.post('/v1/audio/translations', requestValidator, createTranslationHandler); + /** * POST route for '/v1/prompts/:id/completions'. * Handles portkey prompt completions route diff --git a/src/middlewares/hooks/index.ts b/src/middlewares/hooks/index.ts index dd01a67dc..98ad7ace4 100644 --- a/src/middlewares/hooks/index.ts +++ b/src/middlewares/hooks/index.ts @@ -84,7 +84,12 @@ export class HookSpan { } else if (requestParams?.messages?.length) { const lastMessage = requestParams.messages[requestParams.messages.length - 1]; - return lastMessage.content.text || lastMessage.content; + const concatenatedText = Array.isArray(lastMessage.content) + ? 
lastMessage.content + .map((contentPart: any) => contentPart.text) + .join('\n') + : ''; + return concatenatedText || lastMessage.content; } return ''; } diff --git a/src/middlewares/requestValidator/schema/config.ts b/src/middlewares/requestValidator/schema/config.ts index 5e04883a8..c73589787 100644 --- a/src/middlewares/requestValidator/schema/config.ts +++ b/src/middlewares/requestValidator/schema/config.ts @@ -1,4 +1,4 @@ -import { z } from 'zod'; +import { any, z } from 'zod'; import { OLLAMA, VALID_PROVIDERS, GOOGLE_VERTEX_AI } from '../../../globals'; export const configSchema: any = z @@ -8,13 +8,25 @@ export const configSchema: any = z mode: z .string() .refine( - (value) => ['single', 'loadbalance', 'fallback'].includes(value), + (value) => + ['single', 'loadbalance', 'fallback', 'conditional'].includes( + value + ), { message: - "Invalid 'mode' value. Must be one of: single, loadbalance, fallback", + "Invalid 'mode' value. Must be one of: single, loadbalance, fallback, conditional", } ), on_status_codes: z.array(z.number()).optional(), + conditions: z + .array( + z.object({ + query: z.object({}), + then: z.string(), + }) + ) + .optional(), + default: z.string().optional(), }) .optional(), provider: z diff --git a/src/providers/azure-openai/api.ts b/src/providers/azure-openai/api.ts index e60fa37f2..7ccb5a58d 100644 --- a/src/providers/azure-openai/api.ts +++ b/src/providers/azure-openai/api.ts @@ -1,4 +1,3 @@ -import { Options } from '../../types/requestBody'; import { ProviderAPIConfig } from '../types'; const AzureOpenAIAPIConfig: ProviderAPIConfig = { @@ -6,9 +5,13 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = { const { resourceName, deploymentId } = providerOptions; return `https://${resourceName}.openai.azure.com/openai/deployments/${deploymentId}`; }, - headers: ({ providerOptions }) => { - const { apiKey } = providerOptions; - return { 'api-key': `${apiKey}` }; + headers: ({ providerOptions, fn }) => { + const headersObj: Record = { + 'api-key': `${providerOptions.apiKey}`, + }; + if (fn === 'createTranscription' || fn === 'createTranslation') + headersObj['Content-Type'] = 'multipart/form-data'; + return headersObj; }, getEndpoint: ({ providerOptions, fn }) => { const { apiVersion, urlToFetch } = providerOptions; @@ -23,6 +26,12 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = { mappedFn = 'embed'; } else if (urlToFetch?.indexOf('/images/generations') > -1) { mappedFn = 'imageGenerate'; + } else if (urlToFetch?.indexOf('/audio/speech') > -1) { + mappedFn = 'createSpeech'; + } else if (urlToFetch?.indexOf('/audio/transcriptions') > -1) { + mappedFn = 'createTranscription'; + } else if (urlToFetch?.indexOf('/audio/translations') > -1) { + mappedFn = 'createTranslation'; } } @@ -39,6 +48,15 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = { case 'imageGenerate': { return `/images/generations?api-version=${apiVersion}`; } + case 'createSpeech': { + return `/audio/speech?api-version=${apiVersion}`; + } + case 'createTranscription': { + return `/audio/transcriptions?api-version=${apiVersion}`; + } + case 'createTranslation': { + return `/audio/translations?api-version=${apiVersion}`; + } default: return ''; } diff --git a/src/providers/azure-openai/chatComplete.ts b/src/providers/azure-openai/chatComplete.ts index 28e1ca5fd..b07b79f4c 100644 --- a/src/providers/azure-openai/chatComplete.ts +++ b/src/providers/azure-openai/chatComplete.ts @@ -1,5 +1,5 @@ import { AZURE_OPEN_AI } from '../../globals'; -import { OpenAIErrorResponseTransform } from 
'../openai/chatComplete'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; import { ChatCompletionResponse, ErrorResponse, diff --git a/src/providers/azure-openai/complete.ts b/src/providers/azure-openai/complete.ts index 5cbc3e3b7..61ae85d51 100644 --- a/src/providers/azure-openai/complete.ts +++ b/src/providers/azure-openai/complete.ts @@ -1,5 +1,5 @@ import { AZURE_OPEN_AI } from '../../globals'; -import { OpenAIErrorResponseTransform } from '../openai/chatComplete'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; import { CompletionResponse, ErrorResponse, ProviderConfig } from '../types'; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. diff --git a/src/providers/azure-openai/createSpeech.ts b/src/providers/azure-openai/createSpeech.ts new file mode 100644 index 000000000..b2d85a837 --- /dev/null +++ b/src/providers/azure-openai/createSpeech.ts @@ -0,0 +1,41 @@ +import { OPEN_AI } from '../../globals'; +import { ErrorResponse, ProviderConfig } from '../types'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; + +export const AzureOpenAICreateSpeechConfig: ProviderConfig = { + model: { + param: 'model', + required: true, + default: 'tts-1', + }, + input: { + param: 'input', + required: true, + }, + voice: { + param: 'voice', + required: true, + default: 'alloy', + }, + response_format: { + param: 'response_format', + required: false, + default: 'mp3', + }, + speed: { + param: 'speed', + required: false, + default: 1, + }, +}; + +export const AzureOpenAICreateSpeechResponseTransform: ( + response: Response | ErrorResponse, + responseStatus: number +) => Response | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, OPEN_AI); + } + + return response; +}; diff --git a/src/providers/azure-openai/createTranscription.ts b/src/providers/azure-openai/createTranscription.ts new file mode 100644 index 000000000..0b549be75 --- /dev/null +++ b/src/providers/azure-openai/createTranscription.ts @@ -0,0 +1,14 @@ +import { OPEN_AI } from '../../globals'; +import { ErrorResponse } from '../types'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; + +export const AzureOpenAICreateTranscriptionResponseTransform: ( + response: Response | ErrorResponse, + responseStatus: number +) => Response | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, OPEN_AI); + } + + return response; +}; diff --git a/src/providers/azure-openai/createTranslation.ts b/src/providers/azure-openai/createTranslation.ts new file mode 100644 index 000000000..8fcd8fea5 --- /dev/null +++ b/src/providers/azure-openai/createTranslation.ts @@ -0,0 +1,14 @@ +import { OPEN_AI } from '../../globals'; +import { ErrorResponse } from '../types'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; + +export const AzureOpenAICreateTranslationResponseTransform: ( + response: Response | ErrorResponse, + responseStatus: number +) => Response | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 
200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, OPEN_AI); + } + + return response; +}; diff --git a/src/providers/azure-openai/embed.ts b/src/providers/azure-openai/embed.ts index 45984acaa..d475aefae 100644 --- a/src/providers/azure-openai/embed.ts +++ b/src/providers/azure-openai/embed.ts @@ -1,6 +1,6 @@ import { AZURE_OPEN_AI } from '../../globals'; import { EmbedResponse } from '../../types/embedRequestBody'; -import { OpenAIErrorResponseTransform } from '../openai/chatComplete'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; import { ErrorResponse, ProviderConfig } from '../types'; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. diff --git a/src/providers/azure-openai/imageGenerate.ts b/src/providers/azure-openai/imageGenerate.ts index ae562cc43..7c0bc3e16 100644 --- a/src/providers/azure-openai/imageGenerate.ts +++ b/src/providers/azure-openai/imageGenerate.ts @@ -1,5 +1,5 @@ import { AZURE_OPEN_AI } from '../../globals'; -import { OpenAIErrorResponseTransform } from '../openai/chatComplete'; +import { OpenAIErrorResponseTransform } from '../openai/utils'; import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types'; export const AzureOpenAIImageGenerateConfig: ProviderConfig = { diff --git a/src/providers/azure-openai/index.ts b/src/providers/azure-openai/index.ts index 590d55fc0..1b0cc7099 100644 --- a/src/providers/azure-openai/index.ts +++ b/src/providers/azure-openai/index.ts @@ -16,6 +16,12 @@ import { AzureOpenAIImageGenerateConfig, AzureOpenAIImageGenerateResponseTransform, } from './imageGenerate'; +import { + AzureOpenAICreateSpeechConfig, + AzureOpenAICreateSpeechResponseTransform, +} from './createSpeech'; +import { AzureOpenAICreateTranscriptionResponseTransform } from './createTranscription'; +import { AzureOpenAICreateTranslationResponseTransform } from './createTranslation'; const AzureOpenAIConfig: ProviderConfigs = { complete: AzureOpenAICompleteConfig, @@ -23,11 +29,17 @@ const AzureOpenAIConfig: ProviderConfigs = { api: AzureOpenAIAPIConfig, imageGenerate: AzureOpenAIImageGenerateConfig, chatComplete: AzureOpenAIChatCompleteConfig, + createSpeech: AzureOpenAICreateSpeechConfig, + createTranscription: {}, + createTranslation: {}, responseTransforms: { complete: AzureOpenAICompleteResponseTransform, chatComplete: AzureOpenAIChatCompleteResponseTransform, embed: AzureOpenAIEmbedResponseTransform, imageGenerate: AzureOpenAIImageGenerateResponseTransform, + createSpeech: AzureOpenAICreateSpeechResponseTransform, + createTranscription: AzureOpenAICreateTranscriptionResponseTransform, + createTranslation: AzureOpenAICreateTranslationResponseTransform, }, }; diff --git a/src/providers/bedrock/api.ts b/src/providers/bedrock/api.ts index 207b48a8e..3c7350a6c 100644 --- a/src/providers/bedrock/api.ts +++ b/src/providers/bedrock/api.ts @@ -1,4 +1,3 @@ -import { Options } from '../../types/requestBody'; import { ProviderAPIConfig } from '../types'; import { generateAWSHeaders } from './utils'; diff --git a/src/providers/fireworks-ai/imageGenerate.ts b/src/providers/fireworks-ai/imageGenerate.ts index ae50d1bc6..d7b876171 100644 --- 
a/src/providers/fireworks-ai/imageGenerate.ts +++ b/src/providers/fireworks-ai/imageGenerate.ts @@ -86,10 +86,6 @@ interface FireworksAIImageObject { ['X-Fireworks-Billing-Idempotency-Id']: string; } -interface FireworksAIImageGenerateResponse extends ImageGenerateResponse { - data: FireworksAIImageObject[]; -} - export const FireworksAIImageGenerateResponseTransform: ( response: | FireworksAIImageObject[] diff --git a/src/providers/google-vertex-ai/api.ts b/src/providers/google-vertex-ai/api.ts index 1ecb47a0f..17a44c320 100644 --- a/src/providers/google-vertex-ai/api.ts +++ b/src/providers/google-vertex-ai/api.ts @@ -1,22 +1,33 @@ +import { Options } from '../../types/requestBody'; import { ProviderAPIConfig } from '../types'; -import { getModelAndProvider } from './utils'; -import { getAccessToken } from './utils'; +import { getModelAndProvider, getAccessToken } from './utils'; + +const getProjectRoute = ( + providerOptions: Options, + inputModel: string +): string => { + const { + vertexProjectId: inputProjectId, + vertexRegion, + vertexServiceAccountJson, + } = providerOptions; + let projectId = inputProjectId; + if (vertexServiceAccountJson) { + projectId = vertexServiceAccountJson.project_id; + } + + const { provider } = getModelAndProvider(inputModel as string); + const routeVersion = provider === 'meta' ? 'v1beta1' : 'v1'; + return `/${routeVersion}/projects/${projectId}/locations/${vertexRegion}`; +}; // Good reference for using REST: https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#gemini-beginner-samples-drest // Difference versus Studio AI: https://cloud.google.com/vertex-ai/docs/start/ai-platform-users export const GoogleApiConfig: ProviderAPIConfig = { getBaseURL: ({ providerOptions }) => { - const { - vertexProjectId: inputProjectId, - vertexRegion, - vertexServiceAccountJson, - } = providerOptions; - let projectId = inputProjectId; - if (vertexServiceAccountJson) { - projectId = vertexServiceAccountJson.project_id; - } + const { vertexRegion } = providerOptions; - return `https://${vertexRegion}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${vertexRegion}`; + return `https://${vertexRegion}-aiplatform.googleapis.com`; }, headers: async ({ providerOptions }) => { const { apiKey, vertexServiceAccountJson } = providerOptions; @@ -30,7 +41,7 @@ export const GoogleApiConfig: ProviderAPIConfig = { Authorization: `Bearer ${authToken}`, }; }, - getEndpoint: ({ fn, gatewayRequestBody }) => { + getEndpoint: ({ fn, gatewayRequestBody, providerOptions }) => { let mappedFn = fn; const { model: inputModel, stream } = gatewayRequestBody; if (stream) { @@ -38,28 +49,33 @@ export const GoogleApiConfig: ProviderAPIConfig = { } const { provider, model } = getModelAndProvider(inputModel as string); + const projectRoute = getProjectRoute(providerOptions, inputModel as string); switch (provider) { case 'google': { if (mappedFn === 'chatComplete') { - return `/publishers/${provider}/models/${model}:generateContent`; + return `${projectRoute}/publishers/${provider}/models/${model}:generateContent`; } else if (mappedFn === 'stream-chatComplete') { - return `/publishers/${provider}/models/${model}:streamGenerateContent?alt=sse`; + return `${projectRoute}/publishers/${provider}/models/${model}:streamGenerateContent?alt=sse`; } } case 'anthropic': { if (mappedFn === 'chatComplete') { - return `/publishers/${provider}/models/${model}:rawPredict`; + return `${projectRoute}/publishers/${provider}/models/${model}:rawPredict`; } else if 
(mappedFn === 'stream-chatComplete') { - return `/publishers/${provider}/models/${model}:streamRawPredict`; + return `${projectRoute}/publishers/${provider}/models/${model}:streamRawPredict`; } } + case 'meta': { + return `${projectRoute}/endpoints/openapi/chat/completions`; + } + // Embed API is not yet implemented in the gateway // This may be as easy as copy-paste from Google provider, but needs to be tested default: - return ''; + return `${projectRoute}`; } }, }; diff --git a/src/providers/google-vertex-ai/chatComplete.ts b/src/providers/google-vertex-ai/chatComplete.ts index 66bd8a11c..8f8d8d1ea 100644 --- a/src/providers/google-vertex-ai/chatComplete.ts +++ b/src/providers/google-vertex-ai/chatComplete.ts @@ -34,6 +34,8 @@ import { transformGenerationConfig } from './transformGenerationConfig'; import type { GoogleErrorResponse, GoogleGenerateContentResponse, + VertexLlamaChatCompleteStreamChunk, + VertexLLamaChatCompleteResponse, } from './types'; export const VertexGoogleChatCompleteConfig: ProviderConfig = { @@ -643,6 +645,47 @@ export const GoogleChatCompleteResponseTransform: ( return generateInvalidProviderResponseError(response, GOOGLE_VERTEX_AI); }; +export const VertexLlamaChatCompleteConfig: ProviderConfig = { + model: { + param: 'model', + required: true, + default: 'meta/llama3-405b-instruct-maas', + }, + messages: { + param: 'messages', + required: true, + default: [], + }, + max_tokens: { + param: 'max_tokens', + default: 512, + min: 1, + max: 2048, + }, + temperature: { + param: 'temperature', + default: 0.5, + min: 0, + max: 1, + }, + top_p: { + param: 'top_p', + default: 0.9, + min: 0, + max: 1, + }, + top_k: { + param: 'top_k', + default: 0, + min: 0, + max: 2048, + }, + stream: { + param: 'stream', + default: false, + }, +}; + export const GoogleChatCompleteStreamChunkTransform: ( response: string, fallbackId: string, @@ -935,3 +978,50 @@ export const VertexAnthropicChatCompleteStreamChunkTransform: ( })}` + '\n\n' ); }; + +export const VertexLlamaChatCompleteResponseTransform: ( + response: VertexLLamaChatCompleteResponse | GoogleErrorResponse, + responseStatus: number +) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { + if ( + responseStatus !== 200 && + Array.isArray(response) && + response.length > 0 && + 'error' in response[0] + ) { + const { error } = response[0]; + + return generateErrorResponse( + { + message: error.message, + type: error.status, + param: null, + code: String(error.code), + }, + GOOGLE_VERTEX_AI + ); + } + if ('choices' in response) { + return { + id: crypto.randomUUID(), + created: Math.floor(Date.now() / 1000), + provider: GOOGLE_VERTEX_AI, + ...response, + }; + } + return generateInvalidProviderResponseError(response, GOOGLE_VERTEX_AI); +}; + +export const VertexLlamaChatCompleteStreamChunkTransform: ( + response: string, + fallbackId: string +) => string = (responseChunk, fallbackId) => { + let chunk = responseChunk.trim(); + chunk = chunk.replace(/^data: /, ''); + chunk = chunk.trim(); + const parsedChunk: VertexLlamaChatCompleteStreamChunk = JSON.parse(chunk); + parsedChunk.id = fallbackId; + parsedChunk.created = Math.floor(Date.now() / 1000); + parsedChunk.provider = GOOGLE_VERTEX_AI; + return `data: ${JSON.stringify(parsedChunk)}` + '\n\n'; +}; diff --git a/src/providers/google-vertex-ai/index.ts b/src/providers/google-vertex-ai/index.ts index 5a7eae93a..cc3cf2404 100644 --- a/src/providers/google-vertex-ai/index.ts +++ b/src/providers/google-vertex-ai/index.ts @@ -7,6 +7,9 @@ import { 
VertexAnthropicChatCompleteResponseTransform, VertexAnthropicChatCompleteStreamChunkTransform, VertexGoogleChatCompleteConfig, + VertexLlamaChatCompleteConfig, + VertexLlamaChatCompleteResponseTransform, + VertexLlamaChatCompleteStreamChunkTransform, } from './chatComplete'; import { getModelAndProvider } from './utils'; @@ -36,6 +39,15 @@ const VertexConfig: ProviderConfigs = { chatComplete: VertexAnthropicChatCompleteResponseTransform, }, }; + case 'meta': + return { + chatComplete: VertexLlamaChatCompleteConfig, + api: GoogleApiConfig, + responseTransforms: { + chatComplete: VertexLlamaChatCompleteResponseTransform, + 'stream-chatComplete': VertexLlamaChatCompleteStreamChunkTransform, + }, + }; } }, }; diff --git a/src/providers/google-vertex-ai/types.ts b/src/providers/google-vertex-ai/types.ts index 8d054df15..f0c2c80f5 100644 --- a/src/providers/google-vertex-ai/types.ts +++ b/src/providers/google-vertex-ai/types.ts @@ -1,3 +1,5 @@ +import { ChatCompletionResponse } from '../types'; + export interface GoogleErrorResponse { error: { code: number; @@ -42,3 +44,27 @@ export interface GoogleGenerateContentResponse { totalTokenCount: number; }; } + +export interface VertexLLamaChatCompleteResponse + extends Omit {} + +export interface VertexLlamaChatCompleteStreamChunk { + choices: { + delta: { + content: string; + role: string; + }; + finish_reason?: string; + index: 0; + }[]; + model: string; + object: string; + usage?: { + completion_tokens: number; + prompt_tokens: number; + total_tokens: number; + }; + id?: string; + created?: number; + provider?: string; +} diff --git a/src/providers/google-vertex-ai/utils.ts b/src/providers/google-vertex-ai/utils.ts index 6957b32fe..a4f909df9 100644 --- a/src/providers/google-vertex-ai/utils.ts +++ b/src/providers/google-vertex-ai/utils.ts @@ -129,6 +129,8 @@ export const getModelAndProvider = (modelString: string) => { ) { provider = modelStringParts[0]; model = modelStringParts.slice(1).join('.'); + } else if (modelString.includes('llama')) { + provider = 'meta'; } return { provider, model }; diff --git a/src/providers/google/chatComplete.ts b/src/providers/google/chatComplete.ts index 91500c0d0..14287a1f6 100644 --- a/src/providers/google/chatComplete.ts +++ b/src/providers/google/chatComplete.ts @@ -387,7 +387,7 @@ export const GoogleChatCompleteResponseTransform: ( model: 'Unknown', provider: 'google', choices: - response.candidates?.map((generation, index) => { + response.candidates?.map((generation) => { let message: Message = { role: 'assistant', content: '' }; if (generation.content.parts[0]?.text) { message = { @@ -459,7 +459,7 @@ export const GoogleChatCompleteStreamChunkTransform: ( model: '', provider: 'google', choices: - parsedChunk.candidates?.map((generation, index) => { + parsedChunk.candidates?.map((generation) => { let message: Message = { role: 'assistant', content: '' }; if (generation.content.parts[0]?.text) { message = { diff --git a/src/providers/ollama/embed.ts b/src/providers/ollama/embed.ts index fcca9f70d..017f283b6 100644 --- a/src/providers/ollama/embed.ts +++ b/src/providers/ollama/embed.ts @@ -28,7 +28,7 @@ interface OllamaErrorResponse { export const OllamaEmbedResponseTransform: ( response: OllamaEmbedResponse | OllamaErrorResponse, responseStatus: number -) => EmbedResponse | ErrorResponse = (response, responseStatus) => { +) => EmbedResponse | ErrorResponse = (response) => { if ('error' in response) { return generateErrorResponse( { message: response.error, type: null, param: null, code: null }, diff --git 
a/src/providers/openai/api.ts b/src/providers/openai/api.ts index 70725986e..9c2a251f2 100644 --- a/src/providers/openai/api.ts +++ b/src/providers/openai/api.ts @@ -2,7 +2,7 @@ import { ProviderAPIConfig } from '../types'; const OpenAIAPIConfig: ProviderAPIConfig = { getBaseURL: () => 'https://api.openai.com/v1', - headers: ({ providerOptions }) => { + headers: ({ providerOptions, fn }) => { const headersObj: Record = { Authorization: `Bearer ${providerOptions.apiKey}`, }; @@ -14,6 +14,9 @@ const OpenAIAPIConfig: ProviderAPIConfig = { headersObj['OpenAI-Project'] = providerOptions.openaiProject; } + if (fn === 'createTranscription' || fn === 'createTranslation') + headersObj['Content-Type'] = 'multipart/form-data'; + return headersObj; }, getEndpoint: ({ fn }) => { @@ -26,6 +29,12 @@ const OpenAIAPIConfig: ProviderAPIConfig = { return '/embeddings'; case 'imageGenerate': return '/images/generations'; + case 'createSpeech': + return '/audio/speech'; + case 'createTranscription': + return '/audio/transcriptions'; + case 'createTranslation': + return '/audio/translations'; default: return ''; } diff --git a/src/providers/openai/chatComplete.ts b/src/providers/openai/chatComplete.ts index dfef3017e..5e7047306 100644 --- a/src/providers/openai/chatComplete.ts +++ b/src/providers/openai/chatComplete.ts @@ -4,7 +4,7 @@ import { ErrorResponse, ProviderConfig, } from '../types'; -import { generateErrorResponse } from '../utils'; +import { OpenAIErrorResponseTransform } from './utils'; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. @@ -96,18 +96,6 @@ export interface OpenAIChatCompleteResponse extends ChatCompletionResponse { system_fingerprint: string; } -export const OpenAIErrorResponseTransform: ( - response: ErrorResponse, - provider: string -) => ErrorResponse = (response, provider) => { - return generateErrorResponse( - { - ...response.error, - }, - provider - ); -}; - export const OpenAIChatCompleteResponseTransform: ( response: OpenAIChatCompleteResponse | ErrorResponse, responseStatus: number diff --git a/src/providers/openai/complete.ts b/src/providers/openai/complete.ts index 060aa0df1..d6242a3dc 100644 --- a/src/providers/openai/complete.ts +++ b/src/providers/openai/complete.ts @@ -1,6 +1,6 @@ import { OPEN_AI } from '../../globals'; import { CompletionResponse, ErrorResponse, ProviderConfig } from '../types'; -import { OpenAIErrorResponseTransform } from './chatComplete'; +import { OpenAIErrorResponseTransform } from './utils'; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
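
A note on the hard-coded `multipart/form-data` value in the headers function above: it acts only as a marker. `constructRequest` (see the handlerUtils hunk earlier in this diff) detects it and deletes the header before dispatch, so that `fetch` can serialize the FormData body and append the required `boundary` parameter itself. A sketch of the failure mode this sidesteps; `url` and the form field are placeholders:

```ts
const url = 'https://example.invalid/upload'; // placeholder endpoint
const form = new FormData();
form.append('model', 'whisper-1');

// Broken: a hand-written multipart header carries no boundary, so the
// receiving server cannot split the body into parts.
await fetch(url, {
  method: 'POST',
  headers: { 'Content-Type': 'multipart/form-data' },
  body: form,
});

// Working: leave content-type unset and fetch emits
// "multipart/form-data; boundary=..." on its own.
await fetch(url, { method: 'POST', body: form });
```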
diff --git a/src/providers/openai/createSpeech.ts b/src/providers/openai/createSpeech.ts new file mode 100644 index 000000000..93ca172b9 --- /dev/null +++ b/src/providers/openai/createSpeech.ts @@ -0,0 +1,41 @@ +import { OPEN_AI } from '../../globals'; +import { ErrorResponse, ProviderConfig } from '../types'; +import { OpenAIErrorResponseTransform } from './utils'; + +export const OpenAICreateSpeechConfig: ProviderConfig = { + model: { + param: 'model', + required: true, + default: 'tts-1', + }, + input: { + param: 'input', + required: true, + }, + voice: { + param: 'voice', + required: true, + default: 'alloy', + }, + response_format: { + param: 'response_format', + required: false, + default: 'mp3', + }, + speed: { + param: 'speed', + required: false, + default: 1, + }, +}; + +export const OpenAICreateSpeechResponseTransform: ( + response: Response | ErrorResponse, + responseStatus: number +) => Response | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, OPEN_AI); + } + + return response; +}; diff --git a/src/providers/openai/createTranscription.ts b/src/providers/openai/createTranscription.ts new file mode 100644 index 000000000..146060ee1 --- /dev/null +++ b/src/providers/openai/createTranscription.ts @@ -0,0 +1,14 @@ +import { OPEN_AI } from '../../globals'; +import { ErrorResponse } from '../types'; +import { OpenAIErrorResponseTransform } from './utils'; + +export const OpenAICreateTranscriptionResponseTransform: ( + response: Response | ErrorResponse, + responseStatus: number +) => Response | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, OPEN_AI); + } + + return response; +}; diff --git a/src/providers/openai/createTranslation.ts b/src/providers/openai/createTranslation.ts new file mode 100644 index 000000000..51b07554f --- /dev/null +++ b/src/providers/openai/createTranslation.ts @@ -0,0 +1,14 @@ +import { OPEN_AI } from '../../globals'; +import { ErrorResponse } from '../types'; +import { OpenAIErrorResponseTransform } from './utils'; + +export const OpenAICreateTranslationResponseTransform: ( + response: Response | ErrorResponse, + responseStatus: number +) => Response | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, OPEN_AI); + } + + return response; +}; diff --git a/src/providers/openai/embed.ts b/src/providers/openai/embed.ts index 80a572d20..6d929383c 100644 --- a/src/providers/openai/embed.ts +++ b/src/providers/openai/embed.ts @@ -1,7 +1,7 @@ import { OPEN_AI } from '../../globals'; import { EmbedResponse } from '../../types/embedRequestBody'; import { ErrorResponse, ProviderConfig } from '../types'; -import { OpenAIErrorResponseTransform } from './chatComplete'; +import { OpenAIErrorResponseTransform } from './utils'; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
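
Since `OpenAICreateSpeechConfig` marks `model` and `voice` as required with defaults, a request that omits them should still go through with `tts-1`/`alloy` filled in by the transform layer. A hedged sketch of calling the new route; the local URL and header values are assumptions:

```ts
import { writeFile } from 'node:fs/promises';

const res = await fetch('http://localhost:8787/v1/audio/speech', {
  method: 'POST',
  headers: {
    'content-type': 'application/json',
    'x-portkey-provider': 'openai',
    Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
  },
  body: JSON.stringify({
    input: 'The quick brown fox jumped over the lazy dog.',
    // model, voice, response_format, speed fall back to the defaults above
  }),
});

// createSpeech responses are binary audio, not JSON
await writeFile('speech.mp3', Buffer.from(await res.arrayBuffer()));
```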
diff --git a/src/providers/openai/imageGenerate.ts b/src/providers/openai/imageGenerate.ts index bbb75f68b..bbc01491d 100644 --- a/src/providers/openai/imageGenerate.ts +++ b/src/providers/openai/imageGenerate.ts @@ -1,6 +1,6 @@ import { OPEN_AI } from '../../globals'; import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types'; -import { OpenAIErrorResponseTransform } from './chatComplete'; +import { OpenAIErrorResponseTransform } from './utils'; export const OpenAIImageGenerateConfig: ProviderConfig = { prompt: { diff --git a/src/providers/openai/index.ts b/src/providers/openai/index.ts index bd0886d71..7fcb1b0fa 100644 --- a/src/providers/openai/index.ts +++ b/src/providers/openai/index.ts @@ -13,6 +13,12 @@ import { OpenAIImageGenerateConfig, OpenAIImageGenerateResponseTransform, } from './imageGenerate'; +import { + OpenAICreateSpeechConfig, + OpenAICreateSpeechResponseTransform, +} from './createSpeech'; +import { OpenAICreateTranscriptionResponseTransform } from './createTranscription'; +import { OpenAICreateTranslationResponseTransform } from './createTranslation'; const OpenAIConfig: ProviderConfigs = { complete: OpenAICompleteConfig, @@ -20,6 +26,9 @@ const OpenAIConfig: ProviderConfigs = { api: OpenAIAPIConfig, chatComplete: OpenAIChatCompleteConfig, imageGenerate: OpenAIImageGenerateConfig, + createSpeech: OpenAICreateSpeechConfig, + createTranscription: {}, + createTranslation: {}, responseTransforms: { complete: OpenAICompleteResponseTransform, // 'stream-complete': OpenAICompleteResponseTransform, @@ -27,6 +36,9 @@ const OpenAIConfig: ProviderConfigs = { // 'stream-chatComplete': OpenAIChatCompleteResponseTransform, embed: OpenAIEmbedResponseTransform, imageGenerate: OpenAIImageGenerateResponseTransform, + createSpeech: OpenAICreateSpeechResponseTransform, + createTranscription: OpenAICreateTranscriptionResponseTransform, + createTranslation: OpenAICreateTranslationResponseTransform, }, }; diff --git a/src/providers/openai/utils.ts b/src/providers/openai/utils.ts new file mode 100644 index 000000000..da6ed6f9b --- /dev/null +++ b/src/providers/openai/utils.ts @@ -0,0 +1,14 @@ +import { ErrorResponse } from '../types'; +import { generateErrorResponse } from '../utils'; + +export const OpenAIErrorResponseTransform: ( + response: ErrorResponse, + provider: string +) => ErrorResponse = (response, provider) => { + return generateErrorResponse( + { + ...response.error, + }, + provider + ); +}; diff --git a/src/providers/palm/embed.ts b/src/providers/palm/embed.ts index 0bf046a9a..cb8c8bb74 100644 --- a/src/providers/palm/embed.ts +++ b/src/providers/palm/embed.ts @@ -1,5 +1,5 @@ import { PALM } from '../../globals'; -import { EmbedParams, EmbedResponse } from '../../types/embedRequestBody'; +import { EmbedResponse } from '../../types/embedRequestBody'; import { GoogleErrorResponse, GoogleErrorResponseTransform, diff --git a/src/providers/perplexity-ai/chatComplete.ts b/src/providers/perplexity-ai/chatComplete.ts index a8d1992df..af1aa993e 100644 --- a/src/providers/perplexity-ai/chatComplete.ts +++ b/src/providers/perplexity-ai/chatComplete.ts @@ -111,7 +111,7 @@ export interface PerplexityAIChatCompletionStreamChunk { export const PerplexityAIChatCompleteResponseTransform: ( response: PerplexityAIChatCompleteResponse | PerplexityAIErrorResponse, responseStatus: number -) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { +) => ChatCompletionResponse | ErrorResponse = (response) => { if ('error' in response) { return 
generateErrorResponse( { diff --git a/src/providers/reka-ai/chatComplete.ts b/src/providers/reka-ai/chatComplete.ts index d49e0066b..8f046b9e0 100644 --- a/src/providers/reka-ai/chatComplete.ts +++ b/src/providers/reka-ai/chatComplete.ts @@ -1,5 +1,5 @@ import { REKA_AI } from '../../globals'; -import { Message, Params } from '../../types/requestBody'; +import { Params } from '../../types/requestBody'; import { ChatCompletionResponse, ErrorResponse, @@ -145,7 +145,7 @@ export interface RekaAIErrorResponse { export const RekaAIChatCompleteResponseTransform: ( response: RekaAIChatCompleteResponse | RekaAIErrorResponse, responseStatus: number -) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { +) => ChatCompletionResponse | ErrorResponse = (response) => { if ('detail' in response) { return generateErrorResponse( { diff --git a/src/providers/types.ts b/src/providers/types.ts index 93253e481..5f647a22e 100644 --- a/src/providers/types.ts +++ b/src/providers/types.ts @@ -59,7 +59,10 @@ export type endpointStrings = | 'stream-complete' | 'stream-chatComplete' | 'proxy' - | 'imageGenerate'; + | 'imageGenerate' + | 'createSpeech' + | 'createTranscription' + | 'createTranslation'; /** * A collection of API configurations for multiple AI providers. diff --git a/src/providers/workers-ai/api.ts b/src/providers/workers-ai/api.ts index 0f648a3d7..f97f2bb25 100644 --- a/src/providers/workers-ai/api.ts +++ b/src/providers/workers-ai/api.ts @@ -9,7 +9,7 @@ const WorkersAiAPIConfig: ProviderAPIConfig = { const { apiKey } = providerOptions; return { Authorization: `Bearer ${apiKey}` }; }, - getEndpoint: ({ providerOptions, fn, gatewayRequestBody: params }) => { + getEndpoint: ({ fn, gatewayRequestBody: params }) => { const { model } = params; switch (fn) { case 'complete': { diff --git a/src/providers/workers-ai/chatComplete.ts b/src/providers/workers-ai/chatComplete.ts index ae88c3850..64acb7fb3 100644 --- a/src/providers/workers-ai/chatComplete.ts +++ b/src/providers/workers-ai/chatComplete.ts @@ -1,5 +1,4 @@ import { WORKERS_AI } from '../../globals'; -import { Params, Message } from '../../types/requestBody'; import { ChatCompletionResponse, ErrorResponse, diff --git a/src/services/conditionalRouter.ts b/src/services/conditionalRouter.ts new file mode 100644 index 000000000..43553591b --- /dev/null +++ b/src/services/conditionalRouter.ts @@ -0,0 +1,152 @@ +import { StrategyModes, Targets } from '../types/requestBody'; + +type Query = { + [key: string]: any; +}; + +interface RouterContext { + metadata?: Record; +} + +enum Operator { + // Comparison Operators + Equal = '$eq', + NotEqual = '$ne', + GreaterThan = '$gt', + GreaterThanOrEqual = '$gte', + LessThan = '$lt', + LessThanOrEqual = '$lte', + In = '$in', + NotIn = '$nin', + Regex = '$regex', + + // Logical Operators + And = '$and', + Or = '$or', +} + +export class ConditionalRouter { + private config: Targets; + private context: RouterContext; + + constructor(config: Targets, context: RouterContext) { + this.config = config; + this.context = context; + if (this.config.strategy?.mode !== StrategyModes.CONDITIONAL) { + throw new Error('Unsupported strategy mode'); + } + } + + resolveTarget(): Targets { + if (!this.config.strategy?.conditions) { + throw new Error('No conditions passed in the query router'); + } + + for (const condition of this.config.strategy.conditions) { + if (this.evaluateQuery(condition.query)) { + const targetName = condition.then; + return this.findTarget(targetName); + } + } + + // If no conditions 
matched and a default is specified, return the default target + if (this.config.strategy.default) { + return this.findTarget(this.config.strategy.default); + } + + throw new Error('Query router did not resolve to any valid target'); + } + + private evaluateQuery(query: Query): boolean { + for (const [key, value] of Object.entries(query)) { + if (key === Operator.Or && Array.isArray(value)) { + return value.some((subCondition: Query) => + this.evaluateQuery(subCondition) + ); + } + + if (key === Operator.And && Array.isArray(value)) { + return value.every((subCondition: Query) => + this.evaluateQuery(subCondition) + ); + } + + const metadataValue = this.getContextValue(key); + + if (typeof value === 'object' && value !== null) { + if (!this.evaluateOperator(value, metadataValue)) { + return false; + } + } else if (metadataValue !== value) { + return false; + } + } + + return true; + } + + private evaluateOperator(operator: Query, value: any): boolean { + for (const [op, compareValue] of Object.entries(operator)) { + switch (op) { + case Operator.Equal: + if (value !== compareValue) return false; + break; + case Operator.NotEqual: + if (value === compareValue) return false; + break; + case Operator.GreaterThan: + if (!(parseFloat(value) > parseFloat(compareValue))) return false; + break; + case Operator.GreaterThanOrEqual: + if (!(parseFloat(value) >= parseFloat(compareValue))) return false; + break; + case Operator.LessThan: + if (!(parseFloat(value) < parseFloat(compareValue))) return false; + break; + case Operator.LessThanOrEqual: + if (!(parseFloat(value) <= parseFloat(compareValue))) return false; + break; + case Operator.In: + if (!Array.isArray(compareValue) || !compareValue.includes(value)) + return false; + break; + case Operator.NotIn: + if (!Array.isArray(compareValue) || compareValue.includes(value)) + return false; + break; + case Operator.Regex: + try { + const regex = new RegExp(compareValue); + return regex.test(value); + } catch (e) { + return false; + } + default: + throw new Error( + `Unsupported operator used in the query router: ${op}` + ); + } + } + return true; + } + + private findTarget(name: string): Targets { + const index = + this.config.targets?.findIndex((target) => target.name === name) ?? -1; + if (index === -1) { + throw new Error(`Invalid target name found in the query router: ${name}`); + } + + return { + ...this.config.targets?.[index], + index, + }; + } + + private getContextValue(key: string): any { + const parts = key.split('.'); + let value: any = this.context; + value = value[parts[0]]?.[parts[1]]; + return value; + } +} diff --git a/src/services/transformToProviderRequest.ts b/src/services/transformToProviderRequest.ts index b2684a758..631d48dd2 100644 --- a/src/services/transformToProviderRequest.ts +++ b/src/services/transformToProviderRequest.ts @@ -1,4 +1,6 @@ +import { MULTIPART_FORM_DATA_ENDPOINTS } from '../globals'; import ProviderConfigs from '../providers'; +import { endpointStrings } from '../providers/types'; import { Params } from '../types/requestBody'; /** @@ -20,16 +22,6 @@ function setNestedProperty(obj: any, path: string, value: any) { current[parts[parts.length - 1]] = value; } -function setArrayNestedProperties( - obj: any, - path: Array<string>, - value: Array<any> -) { - for (let i = 0; i < path.length; i++) { - setNestedProperty(obj, path[i], value[i]); - } -} - /** * Transforms the request body to match the structure required by the AI provider.
* It also ensures the values for each parameter are within the minimum and maximum @@ -44,7 +36,7 @@ function setArrayNestedProperties( * * @throws {Error} If the provider is not supported. */ -const transformToProviderRequest = ( +const transformToProviderRequestJSON = ( provider: string, params: Params, fn: string @@ -140,4 +132,24 @@ const transformToProviderRequest = ( return transformedRequest; }; +/** + * Transforms the request parameters to the format expected by the provider. + * + * @param {string} provider - The name of the provider (e.g., 'openai', 'anthropic'). + * @param {Params} params - The parameters for the request. + * @param {Params | FormData} inputParams - The original input parameters. + * @param {endpointStrings} fn - The function endpoint being called (e.g., 'complete', 'chatComplete'). + * @returns {Params | FormData} - The transformed request parameters. + */ +export const transformToProviderRequest = ( + provider: string, + params: Params, + inputParams: Params | FormData, + fn: endpointStrings +) => { + return MULTIPART_FORM_DATA_ENDPOINTS.includes(fn) + ? inputParams + : transformToProviderRequestJSON(provider, params as Params, fn); +}; + export default transformToProviderRequest; diff --git a/src/tests/common.test.ts b/src/tests/common.test.ts new file mode 100644 index 000000000..7964e3339 --- /dev/null +++ b/src/tests/common.test.ts @@ -0,0 +1,18 @@ +import Providers from '../providers'; +import testVariables from './resources/testVariables'; +import { executeChatCompletionEndpointTests } from './routeSpecificTestFunctions.ts/chatCompletion'; + +for (const provider in testVariables) { + const variables = testVariables[provider]; + const config = Providers[provider]; + + if (!variables.apiKey) { + console.log(`Skipping ${provider} as API key is not provided`); + continue; + } + + if (config.chatComplete) { + describe(`${provider} /chat/completions endpoint tests:`, () => + executeChatCompletionEndpointTests(provider, variables)); + } +} diff --git a/src/tests/resources/constants.ts b/src/tests/resources/constants.ts new file mode 100644 index 000000000..1aaea3d5f --- /dev/null +++ b/src/tests/resources/constants.ts @@ -0,0 +1,2 @@ +const baseURL = 'http://localhost'; +export const CHAT_COMPLETIONS_ENDPOINT = `${baseURL}/v1/chat/completions`; diff --git a/src/tests/resources/requestTemplates.ts b/src/tests/resources/requestTemplates.ts new file mode 100644 index 000000000..ae3f2c43f --- /dev/null +++ b/src/tests/resources/requestTemplates.ts @@ -0,0 +1,54 @@ +import { Params } from '../../types/requestBody'; + +const CHAT_COMPLETE_WITH_MESSAGE_CONTENT_ARRAYS_REQUEST: Params = { + model: 'MODEL_PLACEHOLDER', + max_tokens: 20, + stream: false, + messages: [ + { + role: 'system', + content: 'You are the half-blood prince', + }, + { + role: 'user', + content: [ + { + type: 'text', + text: 'Can you teach me a useful spell?', + }, + ], + }, + ], +}; + +export const getChatCompleteWithMessageContentArraysRequest = ( + model?: string +) => { + return JSON.stringify({ + ...CHAT_COMPLETE_WITH_MESSAGE_CONTENT_ARRAYS_REQUEST, + model, + }); +}; + +export const CHAT_COMPLETE_WITH_MESSAGE_STRING_REQUEST: Params = { + model: 'MODEL_PLACEHOLDER', + max_tokens: 20, + stream: false, + messages: [ + { + role: 'system', + content: 'You are the half-blood prince', + }, + { + role: 'user', + content: 'Can you teach me a useful spell?', + }, + ], +}; + +export const getChatCompleteWithMessageStringRequest = (model?: string) => { + return JSON.stringify({
...CHAT_COMPLETE_WITH_MESSAGE_STRING_REQUEST, + model, + }); +}; diff --git a/src/tests/resources/testVariables.ts b/src/tests/resources/testVariables.ts new file mode 100644 index 000000000..ab0092007 --- /dev/null +++ b/src/tests/resources/testVariables.ts @@ -0,0 +1,141 @@ +import Providers from '../../providers'; + +export interface TestVariable { + apiKey?: string; + chatCompletions?: { + model: string; + }; +} + +export interface TestVariables { + [key: string]: TestVariable; +} + +const testVariables: TestVariables = { + openai: { + apiKey: process.env.OPENAI_API_KEY, + chatCompletions: { model: 'gpt-3.5-turbo' }, + }, + cohere: { + apiKey: process.env.COHERE_API_KEY, + chatCompletions: { model: 'command-r-plus' }, + }, + anthropic: { + apiKey: process.env.ANTHROPIC_API_KEY, + chatCompletions: { + model: 'claude-3-opus-20240229', + }, + }, + 'azure-openai': { + apiKey: process.env.AZURE_OPENAI_API_KEY, + chatCompletions: { model: '' }, + }, + anyscale: { + apiKey: process.env.ANYSCALE_API_KEY, + chatCompletions: { model: 'j2-light' }, + }, + palm: { + apiKey: process.env.PALM_API_KEY, + chatCompletions: { model: '' }, + }, + 'together-ai': { + apiKey: process.env.TOGETHER_AI_API_KEY, + chatCompletions: { model: 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo' }, + }, + google: { + apiKey: process.env.GOOGLE_API_KEY, + chatCompletions: { model: 'gemini-1.5-flash' }, + }, + 'vertex-ai': { + apiKey: process.env.VERTEX_AI_API_KEY, + chatCompletions: { model: '' }, + }, + 'perplexity-ai': { + apiKey: process.env.PERPLEXITY_AI_API_KEY, + chatCompletions: { model: 'llama-3-sonar-small-32k-online' }, + }, + 'mistral-ai': { + apiKey: process.env.MISTRAL_AI_API_KEY, + chatCompletions: { model: 'open-mistral-nemo' }, + }, + deepinfra: { + apiKey: process.env.DEEPINFRA_API_KEY, + chatCompletions: { model: 'meta-llama/Meta-Llama-3-8B-Instruct' }, + }, + 'stability-ai': { + apiKey: process.env.STABILITY_AI_API_KEY, + chatCompletions: { model: '' }, + }, + nomic: { + apiKey: process.env.NOMIC_API_KEY, + chatCompletions: { model: '' }, + }, + ollama: { + apiKey: process.env.OLLAMA_API_KEY, + chatCompletions: { model: '' }, + }, + ai21: { + apiKey: process.env.AI21_API_KEY, + chatCompletions: { + model: 'j2-ultra', + }, + }, + bedrock: { + apiKey: process.env.BEDROCK_API_KEY, + chatCompletions: { model: '' }, + }, + groq: { + apiKey: process.env.GROQ_API_KEY, + chatCompletions: { model: 'llama3-8b-8192' }, + }, + segmind: { + apiKey: process.env.SEGMIND_API_KEY, + chatCompletions: { model: '' }, + }, + jina: { + apiKey: process.env.JINA_API_KEY, + chatCompletions: { model: '' }, + }, + 'fireworks-ai': { + apiKey: process.env.FIREWORKS_AI_API_KEY, + chatCompletions: { model: '' }, + }, + 'workers-ai': { + apiKey: process.env.WORKERS_AI_API_KEY, + chatCompletions: { model: '' }, + }, + 'reka-ai': { + apiKey: process.env.REKA_AI_API_KEY, + chatCompletions: { model: '' }, + }, + moonshot: { + apiKey: process.env.MOONSHOT_API_KEY, + chatCompletions: { model: '' }, + }, + openrouter: { + apiKey: process.env.OPENROUTER_API_KEY, + chatCompletions: { model: 'meta-llama/llama-3.1-8b-instruct:free' }, + }, + lingyi: { + apiKey: process.env.LINGYI_API_KEY, + chatCompletions: { model: '' }, + }, + zhipu: { + apiKey: process.env.ZHIPU_API_KEY, + chatCompletions: { model: '' }, + }, + 'novita-ai': { + apiKey: process.env.NOVITA_AI_API_KEY, + chatCompletions: { model: '' }, + }, + monsterapi: { + apiKey: process.env.MONSTERAPI_API_KEY, + chatCompletions: { model: 'meta-llama/Meta-Llama-3-8B-Instruct'
}, + }, + predibase: { + apiKey: process.env.PREDIBASE_API_KEY, + chatCompletions: { model: '' }, + }, +}; + +export default testVariables; diff --git a/src/tests/resources/utils.ts b/src/tests/resources/utils.ts new file mode 100644 index 000000000..ee22d56ec --- /dev/null +++ b/src/tests/resources/utils.ts @@ -0,0 +1,10 @@ +export const createDefaultHeaders = ( + provider: string, + authorization: string +) => { + return { + 'x-portkey-provider': provider, + Authorization: authorization, + 'Content-Type': 'application/json', + }; +}; diff --git a/src/tests/routeSpecificTestFunctions.ts/chatCompletion.ts b/src/tests/routeSpecificTestFunctions.ts/chatCompletion.ts new file mode 100644 index 000000000..8bb9dc6e3 --- /dev/null +++ b/src/tests/routeSpecificTestFunctions.ts/chatCompletion.ts @@ -0,0 +1,42 @@ +import app from '../..'; +import { CHAT_COMPLETIONS_ENDPOINT } from '../resources/constants'; +import { + getChatCompleteWithMessageContentArraysRequest, + getChatCompleteWithMessageStringRequest, +} from '../resources/requestTemplates'; +import { TestVariable } from '../resources/testVariables'; +import { createDefaultHeaders } from '../resources/utils'; + +export const executeChatCompletionEndpointTests: ( + providerName: string, + providerVariables: TestVariable +) => void = (providerName, providerVariables) => { + const model = providerVariables.chatCompletions?.model; + const apiKey = providerVariables.apiKey; + if (!model || !apiKey) { + console.warn( + `Skipping ${providerName} as it does not have chat completions options` + ); + return; + } + + test(`${providerName} /chat/completions test message strings`, async () => { + const request = new Request(CHAT_COMPLETIONS_ENDPOINT, { + method: 'POST', + headers: createDefaultHeaders(providerName, apiKey), + body: getChatCompleteWithMessageStringRequest(model), + }); + const res = await app.fetch(request); + expect(res.status).toBe(200); + }); + + test(`${providerName} /chat/completions test message content arrays`, async () => { + const request = new Request(CHAT_COMPLETIONS_ENDPOINT, { + method: 'POST', + headers: createDefaultHeaders(providerName, apiKey), + body: getChatCompleteWithMessageContentArraysRequest(model), + }); + const res = await app.fetch(request); + expect(res.status).toBe(200); + }); +}; diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index 5821b5a0d..38b1d1a2f 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -16,9 +16,23 @@ interface CacheSettings { maxAge?: number; } +export enum StrategyModes { + LOADBALANCE = 'loadbalance', + FALLBACK = 'fallback', + SINGLE = 'single', + CONDITIONAL = 'conditional', +} + interface Strategy { - mode: string; + mode: StrategyModes; onStatusCodes?: Array<number>; + conditions?: { + query: { + [key: string]: any; + }; + then: string; + }[]; + default?: string; } /** @@ -83,6 +97,7 @@ export interface Options { * @interface */ export interface Targets { + name?: string; strategy?: Strategy; /** The name of the provider.
*/ provider?: string | undefined; diff --git a/src/utils.ts b/src/utils.ts index 7bba2b694..184d534bb 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -5,7 +5,6 @@ import { GOOGLE_VERTEX_AI, PERPLEXITY_AI, DEEPINFRA, - OLLAMA, } from './globals'; import { Params } from './types/requestBody'; @@ -27,10 +26,10 @@ export const getStreamModeSplitPattern = ( splitPattern = '\r\n'; } - // Anthropic vertex has \n\n as the pattern + // In Vertex, Anthropic and Llama have \n\n as the split pattern; only Gemini has \r\n\r\n if ( proxyProvider === GOOGLE_VERTEX_AI && - requestURL.indexOf('/publishers/anthropic') === -1 + requestURL.includes('/publishers/google') ) { splitPattern = '\r\n\r\n'; }
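
Note on the new conditional router: `ConditionalRouter` (src/services/conditionalRouter.ts) walks `strategy.conditions` in order, evaluates each `query` against the router context (only one-level dotted keys such as `metadata.<key>` resolve, per `getContextValue`), and returns the first target whose `name` matches the condition's `then`, falling back to `strategy.default`. A minimal sketch of how it is expected to be driven follows; the import paths, target names, and metadata keys are illustrative assumptions, not part of this diff:

```ts
import { ConditionalRouter } from './src/services/conditionalRouter';
import { StrategyModes, Targets } from './src/types/requestBody';

// Hypothetical config: paid users go to 'premium-target', everyone else to the default.
// Assumes Targets also carries a nested `targets` array, as the router code implies.
const config: Targets = {
  strategy: {
    mode: StrategyModes.CONDITIONAL,
    conditions: [
      {
        query: { 'metadata.user_plan': { $eq: 'paid' } },
        then: 'premium-target', // must match a target's `name`
      },
    ],
    default: 'base-target',
  },
  targets: [
    { name: 'premium-target', provider: 'openai' },
    { name: 'base-target', provider: 'openai' },
  ],
};

const router = new ConditionalRouter(config, {
  metadata: { user_plan: 'paid' },
});
// resolveTarget() returns the matched target plus its index within `targets`.
console.log(router.resolveTarget()); // { name: 'premium-target', provider: 'openai', index: 0 }
```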
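The query grammar mirrors MongoDB-style operators: `$eq`, `$ne`, `$gt`, `$gte`, `$lt`, `$lte`, `$in`, `$nin`, and `$regex` for comparisons, plus `$and`/`$or` over arrays of sub-queries. A compound condition might look like the sketch below (the metadata field names are made up):

```ts
// Matches when the request is tagged region 'eu' AND its tier is pro or enterprise.
// A request that fails this query falls through to the next condition, or to
// strategy.default when no condition matches.
const query = {
  $and: [
    { 'metadata.region': { $eq: 'eu' } },
    { 'metadata.tier': { $in: ['pro', 'enterprise'] } },
  ],
};
```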
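On the request-transform side, the exported `transformToProviderRequest` is now a thin dispatcher: endpoints listed in `MULTIPART_FORM_DATA_ENDPOINTS` (defined in src/globals.ts and not shown in this diff, but presumably the new `createTranscription`/`createTranslation` audio routes) forward the caller's `FormData` untouched, which is why those endpoints carry empty (`{}`) provider configs above, while everything else goes through the renamed JSON transform. A hedged usage sketch, with illustrative import paths and parameters:

```ts
import transformToProviderRequest from './src/services/transformToProviderRequest';

// JSON endpoint: parameters are mapped through the provider's ProviderConfig.
const chatParams = {
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'hi' }],
};
const chatBody = transformToProviderRequest('openai', chatParams, chatParams, 'chatComplete');

// Multipart endpoint: the original FormData is returned as-is
// (empty Params here; the multipart branch ignores them).
const form = new FormData();
form.append('model', 'whisper-1');
const audioBody = transformToProviderRequest('openai', {}, form, 'createTranscription');
// audioBody === form
```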