Skip to content

Commit

Permalink
feat: Add Google AI Studio support (for Gemini models w/ an API key) (#…
Browse files Browse the repository at this point in the history
…5359)

* Spike out OpenAIStudio Gemini provider

* Add provider

* Remove rate limiter

* Update schema

* Update float invocation parameter instantiation

* Remove nonexistent canonical parameter name

* Use more robust dependency type

* Update dependencies resolver

* Properly access dictionary

* Use sync streaming for gemini

* Don't use async context manager

* Don't use context manager for stream

* Support gemini provider in UI

See #5348 for implementing tool calling support

* Small fix to type annotations

* Resolve rebase issues

* No need for type dicts here

* Minimize newlines in system prompt

---------

Co-authored-by: Tony Powell <apowell@arize.com>
  • Loading branch information
anticorrelator and cephalization authored Nov 14, 2024
1 parent 9a6d16a commit cf9e7f4
Show file tree
Hide file tree
Showing 24 changed files with 258 additions and 39 deletions.
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ enum GenerativeProviderKey {
OPENAI
ANTHROPIC
AZURE_OPENAI
GEMINI
}

"""
Expand Down
3 changes: 2 additions & 1 deletion app/src/@types/generative.d.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
declare type ModelProvider = "OPENAI" | "AZURE_OPENAI" | "ANTHROPIC";
// TODO: Pull from GenerativeProviderKey in gql schema
declare type ModelProvider = "OPENAI" | "AZURE_OPENAI" | "ANTHROPIC" | "GEMINI";

/**
* The role of a chat message
Expand Down
1 change: 1 addition & 0 deletions app/src/constants/generativeConstants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export const ModelProviders: Record<ModelProvider, string> = {
OPENAI: "OpenAI",
AZURE_OPENAI: "Azure OpenAI",
ANTHROPIC: "Anthropic",
GEMINI: "Gemini",
};

/**
Expand Down
3 changes: 3 additions & 0 deletions app/src/pages/playground/ChatMessageToolCallsEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ export function ChatMessageToolCallsEditor({
return openAIToolCallsJSONSchema as JSONSchema7;
case "ANTHROPIC":
return anthropicToolCallsJSONSchema as JSONSchema7;
// TODO(apowell): #5348 Add Gemini tool calls schema
case "GEMINI":
return openAIToolCallsJSONSchema as JSONSchema7;
}
}, [instance.model.provider]);

Expand Down
6 changes: 5 additions & 1 deletion app/src/pages/playground/InvocationParametersForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ const InvocationParameterFormField = ({
switch (__typename) {
case "InvocationParameterBase":
return null;
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
if (typeof value !== "number" && value !== undefined) return null;
return (
Expand All @@ -57,6 +56,7 @@ const InvocationParameterFormField = ({
onChange={(value) => onChange(value)}
/>
);
case "FloatInvocationParameter":
case "IntInvocationParameter":
return (
<TextField
Expand Down Expand Up @@ -206,6 +206,10 @@ export const InvocationParametersForm = ({
invocationInputField
floatDefaultValue: defaultValue
}
... on FloatInvocationParameter {
invocationInputField
floatDefaultValue: defaultValue
}
... on IntInvocationParameter {
invocationInputField
intDefaultValue: defaultValue
Expand Down
1 change: 1 addition & 0 deletions app/src/pages/playground/PlaygroundCredentialsDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export const ProviderToCredentialNameMap: Record<ModelProvider, string> = {
OPENAI: "OPENAI_API_KEY",
ANTHROPIC: "ANTHROPIC_API_KEY",
AZURE_OPENAI: "AZURE_OPENAI_API_KEY",
GEMINI: "GEMINI_API_KEY",
};

export function PlaygroundCredentialsDropdown() {
Expand Down
2 changes: 2 additions & 0 deletions app/src/pages/playground/PlaygroundTool.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ export function PlaygroundTool({
return openAIToolDefinitionJSONSchema as JSONSchema7;
case "ANTHROPIC":
return anthropicToolDefinitionJSONSchema as JSONSchema7;
case "GEMINI":
return openAIToolDefinitionJSONSchema as JSONSchema7;
}
}, [instance.model.provider]);

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions app/src/pages/playground/__tests__/playgroundUtils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ describe("processAttributeToolCalls", () => {
testSpanToolCall,
expectedTestOpenAIToolCall,
],
// TODO(apowell): #5348 Add Gemini tool tests
GEMINI: ["GEMINI", testSpanToolCall, expectedTestOpenAIToolCall],
};
test.for(Object.values(ProviderToToolCallTestMap))(
"should return %s tools, if they are valid",
Expand Down Expand Up @@ -1118,6 +1120,8 @@ describe("getToolsFromAttributes", () => {
testSpanOpenAITool,
testSpanOpenAIToolJsonSchema,
],
// TODO(apowell): #5348 Add Gemini tool tests
GEMINI: ["GEMINI", testSpanOpenAITool, testSpanOpenAIToolJsonSchema],
};

test.for(Object.values(ProviderToToolTestMap))(
Expand Down
1 change: 1 addition & 0 deletions app/src/pages/playground/constants.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export const modelProviderToModelPrefixMap: Record<ModelProvider, string[]> = {
AZURE_OPENAI: [],
ANTHROPIC: ["claude"],
OPENAI: ["gpt", "o1"],
GEMINI: ["gemini"],
};

export const TOOL_CHOICE_PARAM_CANONICAL_NAME: Extract<
Expand Down
27 changes: 27 additions & 0 deletions app/src/pages/playground/playgroundUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ export function processAttributeToolCalls({
input: toolCallArgs,
};
}
// TODO(apowell): #5348 Add Gemini tool call
case "GEMINI":
return {
id: tool_call.id ?? "",
function: {
name: tool_call.function?.name ?? "",
arguments: toolCallArgs,
},
};
default:
assertUnreachable(provider);
}
Expand Down Expand Up @@ -699,6 +708,12 @@ export const convertInstanceToolsToProvider = ({
targetProvider: provider,
}),
};
// TODO(apowell): #5348 Add Gemini tool definition
case "GEMINI":
return {
...tool,
definition: toOpenAIToolDefinition(tool.definition),
};
default:
assertUnreachable(provider);
}
Expand All @@ -725,6 +740,9 @@ export const convertMessageToolCallsToProvider = ({
toolCall: toOpenAIToolCall(toolCall),
targetProvider: provider,
});
// TODO(apowell): #5348 Add Gemini tool call
case "GEMINI":
return toOpenAIToolCall(toolCall);
default:
assertUnreachable(provider);
}
Expand Down Expand Up @@ -756,6 +774,12 @@ export const createToolForProvider = ({
id: generateToolId(),
definition: createAnthropicToolDefinition(toolNumber),
};
// TODO(apowell): #5348 Add Gemini tool definition
case "GEMINI":
return {
id: generateToolId(),
definition: createOpenAIToolDefinition(toolNumber),
};
default:
assertUnreachable(provider);
}
Expand All @@ -775,6 +799,9 @@ export const createToolCallForProvider = (
return createOpenAIToolCall();
case "ANTHROPIC":
return createAnthropicToolCall();
// TODO(apowell): #5348 Add Gemini tool call
case "GEMINI":
return createOpenAIToolCall();
default:
assertUnreachable(provider);
}
Expand Down
3 changes: 3 additions & 0 deletions app/src/schemas/toolCallSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ type ProviderToToolCallMap = {
OPENAI: OpenAIToolCall;
AZURE_OPENAI: OpenAIToolCall;
ANTHROPIC: AnthropicToolCall;
GEMINI: OpenAIToolCall;
};

/**
Expand Down Expand Up @@ -230,6 +231,8 @@ export const fromOpenAIToolCall = <T extends ModelProvider>({
return openAIToolCallToAnthropic.parse(
toolCall
) as ProviderToToolCallMap[T];
case "GEMINI":
return toolCall as ProviderToToolCallMap[T];
default:
assertUnreachable(targetProvider);
}
Expand Down
5 changes: 5 additions & 0 deletions app/src/schemas/toolChoiceSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ type ProviderToToolChoiceMap = {
OPENAI: OpenaiToolChoice;
AZURE_OPENAI: OpenaiToolChoice;
ANTHROPIC: AnthropicToolChoice;
// TODO(apowell): #5348 Add Gemini tool choice schema
GEMINI: OpenaiToolChoice;
};

/**
Expand Down Expand Up @@ -150,6 +152,9 @@ export const fromOpenAIToolChoice = <T extends ModelProvider>({
return openAIToolChoiceToAnthropicToolChoice.parse(
toolChoice
) as ProviderToToolChoiceMap[T];
// TODO(apowell): #5348 Add Gemini tool choice
case "GEMINI":
return toolChoice as ProviderToToolChoiceMap[T];
default:
assertUnreachable(targetProvider);
}
Expand Down
Loading

0 comments on commit cf9e7f4

Please sign in to comment.