Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(playground): model selector #4971

Merged
merged 4 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/src/@types/generative.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
declare type ModelProvider = "OPENAI" | "AZURE_OPENAI" | "ANTHROPIC";
8 changes: 8 additions & 0 deletions app/src/constants/generativeConstants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
* A mapping of ModelProvider to a human-readable string
*/
export const ModelProviders: Record<ModelProvider, string> = {
OPENAI: "OpenAI",
AZURE_OPENAI: "Azure OpenAI",
ANTHROPIC: "Anthropic",
};
130 changes: 130 additions & 0 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import React, {
Fragment,
ReactNode,
startTransition,
Suspense,
useState,
} from "react";
import { graphql, useLazyLoadQuery } from "react-relay";

import {
Button,
Dialog,
DialogContainer,
Flex,
Form,
Text,
View,
} from "@arizeai/components";

import { ModelProviders } from "@phoenix/constants/generativeConstants";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
import { ModelPicker } from "./ModelPicker";
import { ModelProviderPicker } from "./ModelProviderPicker";
import { PlaygroundInstanceProps } from "./types";

interface ModelConfigButtonProps extends PlaygroundInstanceProps {}
export function ModelConfigButton(props: ModelConfigButtonProps) {
const [dialog, setDialog] = useState<ReactNode>(null);
const instance = usePlaygroundContext((state) =>
state.instances.find(
(instance) => instance.id === props.playgroundInstanceId
)
);

if (!instance) {
throw new Error(
`Playground instance ${props.playgroundInstanceId} not found`
);
}
return (
<Fragment>
<Button
variant="default"
size="compact"
onClick={() => {
startTransition(() => {
setDialog(
<Dialog title="Model Configuration" size="M">
<Suspense>
<ModelConfigDialogContent {...props} />
</Suspense>
</Dialog>
);
});
}}
>
<Flex direction="row" gap="size-100" alignItems="center">
<Text weight="heavy">{ModelProviders[instance.model.provider]}</Text>
<Text>{instance.model.modelName || "--"}</Text>
</Flex>
</Button>
<DialogContainer
type="slideOver"
isDismissable
onDismiss={() => {
setDialog(null);
}}
>
{dialog}
</DialogContainer>
</Fragment>
);
}

interface ModelConfigDialogContentProps extends ModelConfigButtonProps {}
function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
const query = useLazyLoadQuery<ModelConfigButtonDialogQuery>(
graphql`
query ModelConfigButtonDialogQuery {
...ModelPickerFragment
}
`,
{}
);
const { playgroundInstanceId } = props;
const updateModel = usePlaygroundContext((state) => state.updateModel);
const instance = usePlaygroundContext((state) =>
state.instances.find((instance) => instance.id === playgroundInstanceId)
);
if (!instance) {
throw new Error(
`Playground instance ${props.playgroundInstanceId} not found`
);
}

return (
<View padding="size-200">
<Form>
<ModelProviderPicker
provider={instance.model.provider}
onChange={(provider) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
provider,
modelName: null,
},
});
}}
/>
<ModelPicker
modelName={instance.model.modelName}
provider={instance.model.provider}
query={query}
onChange={(modelName) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
provider: instance.model.provider,
modelName,
},
});
}}
/>
</Form>
</View>
);
}
57 changes: 57 additions & 0 deletions app/src/pages/playground/ModelPicker.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import React, { useMemo } from "react";
import { graphql, useFragment } from "react-relay";

import { Item, Picker, PickerProps } from "@arizeai/components";

import { ModelPickerFragment$key } from "./__generated__/ModelPickerFragment.graphql";

type ModelPickerProps = {
query: ModelPickerFragment$key;
onChange: (model: string) => void;
provider: ModelProvider;
modelName: string | null;
} & Omit<
PickerProps<string>,
"children" | "onSelectionChange" | "defaultSelectedKey"
>;

export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
const data = useFragment<ModelPickerFragment$key>(
graphql`
fragment ModelPickerFragment on Query {
modelProviders(vendors: ["OpenAI", "Anthropic"]) {
name
modelNames
}
}
`,
query
);
const modelNames = useMemo(() => {
// TODO: Lowercase is not enough for things like Azure OpenAI
const provider = data.modelProviders.find(
(provider) => provider.name.toLowerCase() === props.provider.toLowerCase()
);
return provider?.modelNames ?? [];
}, [data, props.provider]);
return (
<Picker
label={"Model"}
data-testid="model-picker"
selectedKey={props.modelName ?? undefined}
aria-label="model picker"
placeholder="Select a model"
onSelectionChange={(key) => {
if (typeof key === "string") {
onChange(key);
}
}}
width={"100%"}
{...props}
>
{modelNames.map((modelName) => {
return <Item key={modelName}>{modelName}</Item>;
})}
</Picker>
);
}
41 changes: 41 additions & 0 deletions app/src/pages/playground/ModelProviderPicker.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import React from "react";

import { Item, Picker, PickerProps } from "@arizeai/components";

import { ModelProviders } from "@phoenix/constants/generativeConstants";
import { isModelProvider } from "@phoenix/utils/generativeUtils";

type ModelProviderPickerProps = {
onChange: (provider: ModelProvider) => void;
provider?: ModelProvider;
} & Omit<
PickerProps<ModelProvider>,
"children" | "onSelectionChange" | "defaultSelectedKey"
>;

export function ModelProviderPicker({
onChange,
...props
}: ModelProviderPickerProps) {
return (
<Picker
label={"Provider"}
data-testid="model-provider-picker"
selectedKey={props.provider ?? undefined}
aria-label="Model Provider"
placeholder="Select a provider"
onSelectionChange={(key) => {
const provider = key as string;
if (isModelProvider(provider)) {
onChange(provider);
}
}}
width={"100%"}
{...props}
>
{Object.entries(ModelProviders).map(([key, value]) => {
return <Item key={key}>{value}</Item>;
})}
</Picker>
);
}
8 changes: 7 additions & 1 deletion app/src/pages/playground/PlaygroundTemplate.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
import { AlphabeticIndexIcon } from "@phoenix/components/AlphabeticIndexIcon";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";

import { ModelConfigButton } from "./ModelConfigButton";
import { PlaygroundChatTemplate } from "./PlaygroundChatTemplate";
import { PlaygroundInstanceProps } from "./types";

Expand All @@ -41,7 +42,12 @@ export function PlaygroundTemplate(props: PlaygroundTemplateProps) {
collapsible
variant="compact"
bodyStyle={{ padding: 0 }}
extra={instances.length > 1 ? <DeleteButton {...props} /> : null}
extra={
<Flex direction="row" gap="size-100">
<ModelConfigButton {...props} />
{instances.length > 1 ? <DeleteButton {...props} /> : null}
</Flex>
}
>
{template.__type === "chat" ? (
<PlaygroundChatTemplate {...props} />
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading