From 23958bd82922e4c09c662714c9f1b2db8b6f5dc9 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Mon, 14 Oct 2024 17:06:00 -0700 Subject: [PATCH] fix(playground): plumb through model name and providers (#4999) Co-authored-by: Mikyo King --- app/schema.graphql | 25 +++- .../pages/playground/ModelConfigButton.tsx | 19 +-- app/src/pages/playground/ModelPicker.tsx | 21 +-- .../pages/playground/ModelProviderPicker.tsx | 21 ++- .../PlaygroundCredentialsDropdown.tsx | 2 +- app/src/pages/playground/PlaygroundOutput.tsx | 7 +- .../ModelConfigButtonDialogQuery.graphql.ts | 74 +++++++---- .../ModelPickerFragment.graphql.ts | 56 ++++---- .../ModelProviderPickerFragment.graphql.ts | 64 +++++++++ .../PlaygroundOutputSubscription.graphql.ts | 24 +++- src/phoenix/server/api/queries.py | 123 ++++++++++++------ src/phoenix/server/api/subscriptions.py | 10 +- .../server/api/types/GenerativeProvider.py | 16 +++ src/phoenix/server/api/types/ModelProvider.py | 9 -- 14 files changed, 327 insertions(+), 144 deletions(-) create mode 100644 app/src/pages/playground/__generated__/ModelProviderPickerFragment.graphql.ts create mode 100644 src/phoenix/server/api/types/GenerativeProvider.py delete mode 100644 src/phoenix/server/api/types/ModelProvider.py diff --git a/app/schema.graphql b/app/schema.graphql index f3433d7fde..f1ca6a81cf 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -67,6 +67,7 @@ union Bin = NominalBin | IntervalBin | MissingValueBin input ChatCompletionInput { messages: [ChatCompletionMessageInput!]! + model: GenerativeModelInput! } input ChatCompletionMessageInput { @@ -822,6 +823,22 @@ type Functionality { tracing: Boolean! } +input GenerativeModelInput { + providerKey: GenerativeProviderKey! + name: String! +} + +type GenerativeProvider { + name: String! + key: GenerativeProviderKey! +} + +enum GenerativeProviderKey { + OPENAI + ANTHROPIC + AZURE_OPENAI +} + """ The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID. """ @@ -926,9 +943,8 @@ type Model { ): PerformanceTimeSeries! } -type ModelProvider { - name: String! - modelNames: [String!]! +input ModelNamesInput { + providerKey: GenerativeProviderKey! } type Mutation { @@ -1123,7 +1139,8 @@ type PromptResponse { } type Query { - modelProviders(vendors: [String!]!): [ModelProvider!]! + modelProviders: [GenerativeProvider!]! + modelNames(input: ModelNamesInput!): [String!]! users(first: Int = 50, last: Int, after: String, before: String): UserConnection! userRoles: [UserRole!]! userApiKeys: [UserApiKey!]! diff --git a/app/src/pages/playground/ModelConfigButton.tsx b/app/src/pages/playground/ModelConfigButton.tsx index dd84e028cf..033a929a41 100644 --- a/app/src/pages/playground/ModelConfigButton.tsx +++ b/app/src/pages/playground/ModelConfigButton.tsx @@ -76,14 +76,6 @@ export function ModelConfigButton(props: ModelConfigButtonProps) { interface ModelConfigDialogContentProps extends ModelConfigButtonProps {} function ModelConfigDialogContent(props: ModelConfigDialogContentProps) { - const query = useLazyLoadQuery( - graphql` - query ModelConfigButtonDialogQuery { - ...ModelPickerFragment - } - `, - {} - ); const { playgroundInstanceId } = props; const updateModel = usePlaygroundContext((state) => state.updateModel); const instance = usePlaygroundContext((state) => @@ -94,12 +86,21 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) { `Playground instance ${props.playgroundInstanceId} not found` ); } - + const query = useLazyLoadQuery( + graphql` + query ModelConfigButtonDialogQuery($providerKey: GenerativeProviderKey!) { + ...ModelProviderPickerFragment + ...ModelPickerFragment @arguments(providerKey: $providerKey) + } + `, + { providerKey: instance.model.provider } + ); return (
{ updateModel({ instanceId: playgroundInstanceId, diff --git a/app/src/pages/playground/ModelPicker.tsx b/app/src/pages/playground/ModelPicker.tsx index eef54e2d11..1d9258e803 100644 --- a/app/src/pages/playground/ModelPicker.tsx +++ b/app/src/pages/playground/ModelPicker.tsx @@ -1,4 +1,4 @@ -import React, { useMemo } from "react"; +import React from "react"; import { graphql, useFragment } from "react-relay"; import { Item, Picker, PickerProps } from "@arizeai/components"; @@ -18,22 +18,15 @@ type ModelPickerProps = { export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) { const data = useFragment( graphql` - fragment ModelPickerFragment on Query { - modelProviders(vendors: ["OpenAI", "Anthropic"]) { - name - modelNames - } + fragment ModelPickerFragment on Query + @argumentDefinitions( + providerKey: { type: "GenerativeProviderKey!", defaultValue: OPENAI } + ) { + modelNames(input: { providerKey: $providerKey }) } `, query ); - const modelNames = useMemo(() => { - // TODO: Lowercase is not enough for things like Azure OpenAI - const provider = data.modelProviders.find( - (provider) => provider.name.toLowerCase() === props.provider.toLowerCase() - ); - return provider?.modelNames ?? []; - }, [data, props.provider]); return ( - {modelNames.map((modelName) => { + {data.modelNames.map((modelName) => { return {modelName}; })} diff --git a/app/src/pages/playground/ModelProviderPicker.tsx b/app/src/pages/playground/ModelProviderPicker.tsx index f37470d87e..94a30e34e5 100644 --- a/app/src/pages/playground/ModelProviderPicker.tsx +++ b/app/src/pages/playground/ModelProviderPicker.tsx @@ -1,12 +1,15 @@ import React from "react"; +import { graphql, useFragment } from "react-relay"; import { Item, Picker, PickerProps } from "@arizeai/components"; -import { ModelProviders } from "@phoenix/constants/generativeConstants"; import { isModelProvider } from "@phoenix/utils/generativeUtils"; +import type { ModelProviderPickerFragment$key } from "./__generated__/ModelProviderPickerFragment.graphql"; + type ModelProviderPickerProps = { onChange: (provider: ModelProvider) => void; + query: ModelProviderPickerFragment$key; provider?: ModelProvider; } & Omit< PickerProps, @@ -15,8 +18,20 @@ type ModelProviderPickerProps = { export function ModelProviderPicker({ onChange, + query, ...props }: ModelProviderPickerProps) { + const data = useFragment( + graphql` + fragment ModelProviderPickerFragment on Query { + modelProviders { + key + name + } + } + `, + query + ); return ( - {Object.entries(ModelProviders).map(([key, value]) => { - return {value}; + {data.modelProviders.map((provider) => { + return {provider.name}; })} ); diff --git a/app/src/pages/playground/PlaygroundCredentialsDropdown.tsx b/app/src/pages/playground/PlaygroundCredentialsDropdown.tsx index 3887c19ed5..8b83f19d1f 100644 --- a/app/src/pages/playground/PlaygroundCredentialsDropdown.tsx +++ b/app/src/pages/playground/PlaygroundCredentialsDropdown.tsx @@ -48,7 +48,7 @@ export function PlaygroundCredentialsDropdown() { API Keys - + API keys are stored in your browser and used to communicate with their respective API's. diff --git a/app/src/pages/playground/PlaygroundOutput.tsx b/app/src/pages/playground/PlaygroundOutput.tsx index e5c41c42c8..d957269532 100644 --- a/app/src/pages/playground/PlaygroundOutput.tsx +++ b/app/src/pages/playground/PlaygroundOutput.tsx @@ -103,8 +103,9 @@ function useChatCompletionSubscription({ subscription: graphql` subscription PlaygroundOutputSubscription( $messages: [ChatCompletionMessageInput!]! + $model: GenerativeModelInput! ) { - chatCompletion(input: { messages: $messages }) + chatCompletion(input: { messages: $messages, model: $model }) } `, variables: params, @@ -177,6 +178,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { useChatCompletionSubscription({ params: { messages: instance.template.messages.map(toGqlChatCompletionMessage), + model: { + providerKey: instance.model.provider, + name: instance.model.modelName || "", + }, }, runId: instance.activeRunId, onNext: (response) => { diff --git a/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts b/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts index f2d7f7a083..187da00815 100644 --- a/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts +++ b/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<5c8d927475f5ac2ddd05a0c07c6c5b2b>> + * @generated SignedSource<<176456afea57f0245ab80564600db337>> * @lightSyntaxTransform * @nogrep */ @@ -10,18 +10,36 @@ import { ConcreteRequest, Query } from 'relay-runtime'; import { FragmentRefs } from "relay-runtime"; -export type ModelConfigButtonDialogQuery$variables = Record; +export type GenerativeProviderKey = "ANTHROPIC" | "AZURE_OPENAI" | "OPENAI"; +export type ModelConfigButtonDialogQuery$variables = { + providerKey: GenerativeProviderKey; +}; export type ModelConfigButtonDialogQuery$data = { - readonly " $fragmentSpreads": FragmentRefs<"ModelPickerFragment">; + readonly " $fragmentSpreads": FragmentRefs<"ModelPickerFragment" | "ModelProviderPickerFragment">; }; export type ModelConfigButtonDialogQuery = { response: ModelConfigButtonDialogQuery$data; variables: ModelConfigButtonDialogQuery$variables; }; -const node: ConcreteRequest = { +const node: ConcreteRequest = (function(){ +var v0 = [ + { + "defaultValue": null, + "kind": "LocalArgument", + "name": "providerKey" + } +], +v1 = [ + { + "kind": "Variable", + "name": "providerKey", + "variableName": "providerKey" + } +]; +return { "fragment": { - "argumentDefinitions": [], + "argumentDefinitions": (v0/*: any*/), "kind": "Fragment", "metadata": null, "name": "ModelConfigButtonDialogQuery", @@ -29,6 +47,11 @@ const node: ConcreteRequest = { { "args": null, "kind": "FragmentSpread", + "name": "ModelProviderPickerFragment" + }, + { + "args": (v1/*: any*/), + "kind": "FragmentSpread", "name": "ModelPickerFragment" } ], @@ -37,23 +60,14 @@ const node: ConcreteRequest = { }, "kind": "Request", "operation": { - "argumentDefinitions": [], + "argumentDefinitions": (v0/*: any*/), "kind": "Operation", "name": "ModelConfigButtonDialogQuery", "selections": [ { "alias": null, - "args": [ - { - "kind": "Literal", - "name": "vendors", - "value": [ - "OpenAI", - "Anthropic" - ] - } - ], - "concreteType": "ModelProvider", + "args": null, + "concreteType": "GenerativeProvider", "kind": "LinkedField", "name": "modelProviders", "plural": true, @@ -62,31 +76,45 @@ const node: ConcreteRequest = { "alias": null, "args": null, "kind": "ScalarField", - "name": "name", + "name": "key", "storageKey": null }, { "alias": null, "args": null, "kind": "ScalarField", - "name": "modelNames", + "name": "name", "storageKey": null } ], - "storageKey": "modelProviders(vendors:[\"OpenAI\",\"Anthropic\"])" + "storageKey": null + }, + { + "alias": null, + "args": [ + { + "fields": (v1/*: any*/), + "kind": "ObjectValue", + "name": "input" + } + ], + "kind": "ScalarField", + "name": "modelNames", + "storageKey": null } ] }, "params": { - "cacheID": "cf6e53f2293c51d168b08e5d2fc7b391", + "cacheID": "34f8d81e91b335ca310c9be756719426", "id": null, "metadata": {}, "name": "ModelConfigButtonDialogQuery", "operationKind": "query", - "text": "query ModelConfigButtonDialogQuery {\n ...ModelPickerFragment\n}\n\nfragment ModelPickerFragment on Query {\n modelProviders(vendors: [\"OpenAI\", \"Anthropic\"]) {\n name\n modelNames\n }\n}\n" + "text": "query ModelConfigButtonDialogQuery(\n $providerKey: GenerativeProviderKey!\n) {\n ...ModelProviderPickerFragment\n ...ModelPickerFragment_3rERSq\n}\n\nfragment ModelPickerFragment_3rERSq on Query {\n modelNames(input: {providerKey: $providerKey})\n}\n\nfragment ModelProviderPickerFragment on Query {\n modelProviders {\n key\n name\n }\n}\n" } }; +})(); -(node as any).hash = "98dc1b22aa68897365be9b248625fc1d"; +(node as any).hash = "c9b38e766093b2378047d22b01ef0fbf"; export default node; diff --git a/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts b/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts index e6d3242e38..9dda921ba2 100644 --- a/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts +++ b/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<3ea2d3c123a67eedf3cfe075d801e5e2>> + * @generated SignedSource<<6931dc528aea2b22801320e6d297dd58>> * @lightSyntaxTransform * @nogrep */ @@ -11,10 +11,7 @@ import { Fragment, ReaderFragment } from 'relay-runtime'; import { FragmentRefs } from "relay-runtime"; export type ModelPickerFragment$data = { - readonly modelProviders: ReadonlyArray<{ - readonly modelNames: ReadonlyArray; - readonly name: string; - }>; + readonly modelNames: ReadonlyArray; readonly " $fragmentType": "ModelPickerFragment"; }; export type ModelPickerFragment$key = { @@ -23,7 +20,13 @@ export type ModelPickerFragment$key = { }; const node: ReaderFragment = { - "argumentDefinitions": [], + "argumentDefinitions": [ + { + "defaultValue": "OPENAI", + "kind": "LocalArgument", + "name": "providerKey" + } + ], "kind": "Fragment", "metadata": null, "name": "ModelPickerFragment", @@ -32,41 +35,26 @@ const node: ReaderFragment = { "alias": null, "args": [ { - "kind": "Literal", - "name": "vendors", - "value": [ - "OpenAI", - "Anthropic" - ] - } - ], - "concreteType": "ModelProvider", - "kind": "LinkedField", - "name": "modelProviders", - "plural": true, - "selections": [ - { - "alias": null, - "args": null, - "kind": "ScalarField", - "name": "name", - "storageKey": null - }, - { - "alias": null, - "args": null, - "kind": "ScalarField", - "name": "modelNames", - "storageKey": null + "fields": [ + { + "kind": "Variable", + "name": "providerKey", + "variableName": "providerKey" + } + ], + "kind": "ObjectValue", + "name": "input" } ], - "storageKey": "modelProviders(vendors:[\"OpenAI\",\"Anthropic\"])" + "kind": "ScalarField", + "name": "modelNames", + "storageKey": null } ], "type": "Query", "abstractKey": null }; -(node as any).hash = "0c369d2705c164e6d3bf698b9fdaa934"; +(node as any).hash = "bb2557396c978bb5f57c7a4f67d756b1"; export default node; diff --git a/app/src/pages/playground/__generated__/ModelProviderPickerFragment.graphql.ts b/app/src/pages/playground/__generated__/ModelProviderPickerFragment.graphql.ts new file mode 100644 index 0000000000..fd487f655f --- /dev/null +++ b/app/src/pages/playground/__generated__/ModelProviderPickerFragment.graphql.ts @@ -0,0 +1,64 @@ +/** + * @generated SignedSource<<8d3d09b89a6d54cc8b22d75946b7094b>> + * @lightSyntaxTransform + * @nogrep + */ + +/* tslint:disable */ +/* eslint-disable */ +// @ts-nocheck + +import { Fragment, ReaderFragment } from 'relay-runtime'; +export type GenerativeProviderKey = "ANTHROPIC" | "AZURE_OPENAI" | "OPENAI"; +import { FragmentRefs } from "relay-runtime"; +export type ModelProviderPickerFragment$data = { + readonly modelProviders: ReadonlyArray<{ + readonly key: GenerativeProviderKey; + readonly name: string; + }>; + readonly " $fragmentType": "ModelProviderPickerFragment"; +}; +export type ModelProviderPickerFragment$key = { + readonly " $data"?: ModelProviderPickerFragment$data; + readonly " $fragmentSpreads": FragmentRefs<"ModelProviderPickerFragment">; +}; + +const node: ReaderFragment = { + "argumentDefinitions": [], + "kind": "Fragment", + "metadata": null, + "name": "ModelProviderPickerFragment", + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "GenerativeProvider", + "kind": "LinkedField", + "name": "modelProviders", + "plural": true, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "key", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "name", + "storageKey": null + } + ], + "storageKey": null + } + ], + "type": "Query", + "abstractKey": null +}; + +(node as any).hash = "c83e86a2772127916f7387dca27b74ce"; + +export default node; diff --git a/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts b/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts index ed48792e68..1e51c5f304 100644 --- a/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts +++ b/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<767976775ee226eb849cf909b42f2897>> + * @generated SignedSource<> * @lightSyntaxTransform * @nogrep */ @@ -10,12 +10,18 @@ import { ConcreteRequest, GraphQLSubscription } from 'relay-runtime'; export type ChatCompletionMessageRole = "AI" | "SYSTEM" | "TOOL" | "USER"; +export type GenerativeProviderKey = "ANTHROPIC" | "AZURE_OPENAI" | "OPENAI"; export type ChatCompletionMessageInput = { content: any; role: ChatCompletionMessageRole; }; +export type GenerativeModelInput = { + name: string; + providerKey: GenerativeProviderKey; +}; export type PlaygroundOutputSubscription$variables = { messages: ReadonlyArray; + model: GenerativeModelInput; }; export type PlaygroundOutputSubscription$data = { readonly chatCompletion: string; @@ -31,6 +37,11 @@ var v0 = [ "defaultValue": null, "kind": "LocalArgument", "name": "messages" + }, + { + "defaultValue": null, + "kind": "LocalArgument", + "name": "model" } ], v1 = [ @@ -43,6 +54,11 @@ v1 = [ "kind": "Variable", "name": "messages", "variableName": "messages" + }, + { + "kind": "Variable", + "name": "model", + "variableName": "model" } ], "kind": "ObjectValue", @@ -72,16 +88,16 @@ return { "selections": (v1/*: any*/) }, "params": { - "cacheID": "0856059e2e3a28a8fda74106924c480d", + "cacheID": "441178357c664007fd0a3713b565dffd", "id": null, "metadata": {}, "name": "PlaygroundOutputSubscription", "operationKind": "subscription", - "text": "subscription PlaygroundOutputSubscription(\n $messages: [ChatCompletionMessageInput!]!\n) {\n chatCompletion(input: {messages: $messages})\n}\n" + "text": "subscription PlaygroundOutputSubscription(\n $messages: [ChatCompletionMessageInput!]!\n $model: GenerativeModelInput!\n) {\n chatCompletion(input: {messages: $messages, model: $model})\n}\n" } }; })(); -(node as any).hash = "d760af0f8e18301631ba9dcdb148de0b"; +(node as any).hash = "7dc12b37e3f80c94a3ca91d20b3292a7"; export default node; diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index 6ca3e5c848..4c8319555e 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -11,7 +11,7 @@ from strawberry import ID, UNSET from strawberry.relay import Connection, GlobalID, Node from strawberry.types import Info -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Annotated, TypeAlias, assert_never from phoenix.db import enums, models from phoenix.db.models import ( @@ -58,9 +58,12 @@ from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run from phoenix.server.api.types.Functionality import Functionality +from phoenix.server.api.types.GenerativeProvider import ( + GenerativeProvider, + GenerativeProviderKey, +) from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole from phoenix.server.api.types.Model import Model -from phoenix.server.api.types.ModelProvider import ModelProvider from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type from phoenix.server.api.types.pagination import ( ConnectionArgs, @@ -77,50 +80,88 @@ from phoenix.server.api.types.UserRole import UserRole +@strawberry.input +class ModelNamesInput: + provider_key: GenerativeProviderKey + + @strawberry.type class Query: @strawberry.field - async def model_providers( - self, vendors: List[str], info: Info[Context, None] - ) -> List[ModelProvider]: - all_vendors = { - "OpenAI": ModelProvider( # https://platform.openai.com/docs/models - name="OpenAI", # currently only models using the chat completions API - model_names=[ - "o1-preview", - "o1-preview-2024-09-12", - "o1-mini", - "o1-mini-2024-09-12", - "gpt-4o", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "chatgpt-4o-latest", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4", - "gpt-4-0613", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-instruct", - ], + async def model_providers(self) -> List[GenerativeProvider]: + return [ + GenerativeProvider( + name="OpenAI", + key=GenerativeProviderKey.OPENAI, ), - "Anthropic": ModelProvider( # https://docs.anthropic.com/en/docs/about-claude/models#model-comparison - name="Anthropic", # currently only models using the messages API - model_names=[ - "claude-3-5-sonnet-20240620", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - ], + GenerativeProvider( + name="Azure OpenAI", + key=GenerativeProviderKey.AZURE_OPENAI, ), - } - return [all_vendors[vendor] for vendor in vendors] + GenerativeProvider( + name="Anthropic", + key=GenerativeProviderKey.ANTHROPIC, + ), + ] + + @strawberry.field + async def model_names(self, input: ModelNamesInput) -> List[str]: + if (provider_key := input.provider_key) == GenerativeProviderKey.OPENAI: + return [ + "o1-preview", + "o1-preview-2024-09-12", + "o1-mini", + "o1-mini-2024-09-12", + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-0613", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-instruct", + ] + if provider_key == GenerativeProviderKey.AZURE_OPENAI: + return [ + "o1-preview", + "o1-preview-2024-09-12", + "o1-mini", + "o1-mini-2024-09-12", + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-0613", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-instruct", + ] + if provider_key == GenerativeProviderKey.ANTHROPIC: + return [ + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + ] + assert_never(provider_key) @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def users( diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 942fa12078..5842431a93 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -26,6 +26,7 @@ from phoenix.server.api.context import Context from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole +from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey from phoenix.server.dml_event import SpanInsertEvent from phoenix.trace.attributes import unflatten @@ -37,9 +38,16 @@ PLAYGROUND_PROJECT_NAME = "playground" +@strawberry.input +class GenerativeModelInput: + provider_key: GenerativeProviderKey + name: str + + @strawberry.input class ChatCompletionInput: messages: List[ChatCompletionMessageInput] + model: GenerativeModelInput def to_openai_chat_completion_param( @@ -109,7 +117,7 @@ async def chat_completion( role: Optional[str] = None async for chunk in await client.chat.completions.create( messages=(to_openai_chat_completion_param(message) for message in input.messages), - model="gpt-4", + model=input.model.name, stream=True, ): chunks.append(chunk) diff --git a/src/phoenix/server/api/types/GenerativeProvider.py b/src/phoenix/server/api/types/GenerativeProvider.py new file mode 100644 index 0000000000..a9a41ce5ac --- /dev/null +++ b/src/phoenix/server/api/types/GenerativeProvider.py @@ -0,0 +1,16 @@ +from enum import Enum + +import strawberry + + +@strawberry.enum +class GenerativeProviderKey(Enum): + OPENAI = "OPENAI" + ANTHROPIC = "ANTHROPIC" + AZURE_OPENAI = "AZURE_OPENAI" + + +@strawberry.type +class GenerativeProvider: + name: str + key: GenerativeProviderKey diff --git a/src/phoenix/server/api/types/ModelProvider.py b/src/phoenix/server/api/types/ModelProvider.py deleted file mode 100644 index 680bc3aaaa..0000000000 --- a/src/phoenix/server/api/types/ModelProvider.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import List - -import strawberry - - -@strawberry.type -class ModelProvider: - name: str - model_names: List[str]