Skip to content

Commit

Permalink
feat(playground): model selector (#4971)
Browse files Browse the repository at this point in the history
* global declaration of generative

* feat(playground): rudamentary model selector

* model provider in tests

* WIP
  • Loading branch information
mikeldking authored Oct 11, 2024
1 parent 0d1c141 commit 025c33e
Show file tree
Hide file tree
Showing 11 changed files with 466 additions and 1 deletion.
1 change: 1 addition & 0 deletions app/src/@types/generative.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
declare type ModelProvider = "OPENAI" | "AZURE_OPENAI" | "ANTHROPIC";
8 changes: 8 additions & 0 deletions app/src/constants/generativeConstants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
* A mapping of ModelProvider to a human-readable string
*/
export const ModelProviders: Record<ModelProvider, string> = {
OPENAI: "OpenAI",
AZURE_OPENAI: "Azure OpenAI",
ANTHROPIC: "Anthropic",
};
130 changes: 130 additions & 0 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import React, {
Fragment,
ReactNode,
startTransition,
Suspense,
useState,
} from "react";
import { graphql, useLazyLoadQuery } from "react-relay";

import {
Button,
Dialog,
DialogContainer,
Flex,
Form,
Text,
View,
} from "@arizeai/components";

import { ModelProviders } from "@phoenix/constants/generativeConstants";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
import { ModelPicker } from "./ModelPicker";
import { ModelProviderPicker } from "./ModelProviderPicker";
import { PlaygroundInstanceProps } from "./types";

interface ModelConfigButtonProps extends PlaygroundInstanceProps {}
export function ModelConfigButton(props: ModelConfigButtonProps) {
const [dialog, setDialog] = useState<ReactNode>(null);
const instance = usePlaygroundContext((state) =>
state.instances.find(
(instance) => instance.id === props.playgroundInstanceId
)
);

if (!instance) {
throw new Error(
`Playground instance ${props.playgroundInstanceId} not found`
);
}
return (
<Fragment>
<Button
variant="default"
size="compact"
onClick={() => {
startTransition(() => {
setDialog(
<Dialog title="Model Configuration" size="M">
<Suspense>
<ModelConfigDialogContent {...props} />
</Suspense>
</Dialog>
);
});
}}
>
<Flex direction="row" gap="size-100" alignItems="center">
<Text weight="heavy">{ModelProviders[instance.model.provider]}</Text>
<Text>{instance.model.modelName || "--"}</Text>
</Flex>
</Button>
<DialogContainer
type="slideOver"
isDismissable
onDismiss={() => {
setDialog(null);
}}
>
{dialog}
</DialogContainer>
</Fragment>
);
}

interface ModelConfigDialogContentProps extends ModelConfigButtonProps {}
function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
const query = useLazyLoadQuery<ModelConfigButtonDialogQuery>(
graphql`
query ModelConfigButtonDialogQuery {
...ModelPickerFragment
}
`,
{}
);
const { playgroundInstanceId } = props;
const updateModel = usePlaygroundContext((state) => state.updateModel);
const instance = usePlaygroundContext((state) =>
state.instances.find((instance) => instance.id === playgroundInstanceId)
);
if (!instance) {
throw new Error(
`Playground instance ${props.playgroundInstanceId} not found`
);
}

return (
<View padding="size-200">
<Form>
<ModelProviderPicker
provider={instance.model.provider}
onChange={(provider) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
provider,
modelName: null,
},
});
}}
/>
<ModelPicker
modelName={instance.model.modelName}
provider={instance.model.provider}
query={query}
onChange={(modelName) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
provider: instance.model.provider,
modelName,
},
});
}}
/>
</Form>
</View>
);
}
57 changes: 57 additions & 0 deletions app/src/pages/playground/ModelPicker.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import React, { useMemo } from "react";
import { graphql, useFragment } from "react-relay";

import { Item, Picker, PickerProps } from "@arizeai/components";

import { ModelPickerFragment$key } from "./__generated__/ModelPickerFragment.graphql";

type ModelPickerProps = {
query: ModelPickerFragment$key;
onChange: (model: string) => void;
provider: ModelProvider;
modelName: string | null;
} & Omit<
PickerProps<string>,
"children" | "onSelectionChange" | "defaultSelectedKey"
>;

export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
const data = useFragment<ModelPickerFragment$key>(
graphql`
fragment ModelPickerFragment on Query {
modelProviders(vendors: ["OpenAI", "Anthropic"]) {
name
modelNames
}
}
`,
query
);
const modelNames = useMemo(() => {
// TODO: Lowercase is not enough for things like Azure OpenAI
const provider = data.modelProviders.find(
(provider) => provider.name.toLowerCase() === props.provider.toLowerCase()
);
return provider?.modelNames ?? [];
}, [data, props.provider]);
return (
<Picker
label={"Model"}
data-testid="model-picker"
selectedKey={props.modelName ?? undefined}
aria-label="model picker"
placeholder="Select a model"
onSelectionChange={(key) => {
if (typeof key === "string") {
onChange(key);
}
}}
width={"100%"}
{...props}
>
{modelNames.map((modelName) => {
return <Item key={modelName}>{modelName}</Item>;
})}
</Picker>
);
}
41 changes: 41 additions & 0 deletions app/src/pages/playground/ModelProviderPicker.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import React from "react";

import { Item, Picker, PickerProps } from "@arizeai/components";

import { ModelProviders } from "@phoenix/constants/generativeConstants";
import { isModelProvider } from "@phoenix/utils/generativeUtils";

type ModelProviderPickerProps = {
onChange: (provider: ModelProvider) => void;
provider?: ModelProvider;
} & Omit<
PickerProps<ModelProvider>,
"children" | "onSelectionChange" | "defaultSelectedKey"
>;

export function ModelProviderPicker({
onChange,
...props
}: ModelProviderPickerProps) {
return (
<Picker
label={"Provider"}
data-testid="model-provider-picker"
selectedKey={props.provider ?? undefined}
aria-label="Model Provider"
placeholder="Select a provider"
onSelectionChange={(key) => {
const provider = key as string;
if (isModelProvider(provider)) {
onChange(provider);
}
}}
width={"100%"}
{...props}
>
{Object.entries(ModelProviders).map(([key, value]) => {
return <Item key={key}>{value}</Item>;
})}
</Picker>
);
}
8 changes: 7 additions & 1 deletion app/src/pages/playground/PlaygroundTemplate.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
import { AlphabeticIndexIcon } from "@phoenix/components/AlphabeticIndexIcon";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";

import { ModelConfigButton } from "./ModelConfigButton";
import { PlaygroundChatTemplate } from "./PlaygroundChatTemplate";
import { PlaygroundInstanceProps } from "./types";

Expand All @@ -41,7 +42,12 @@ export function PlaygroundTemplate(props: PlaygroundTemplateProps) {
collapsible
variant="compact"
bodyStyle={{ padding: 0 }}
extra={instances.length > 1 ? <DeleteButton {...props} /> : null}
extra={
<Flex direction="row" gap="size-100">
<ModelConfigButton {...props} />
{instances.length > 1 ? <DeleteButton {...props} /> : null}
</Flex>
}
>
{template.__type === "chat" ? (
<PlaygroundChatTemplate {...props} />
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 025c33e

Please sign in to comment.