Skip to content

Allow using 3rd party AI services that are compatible with OpenAI API format in the openai-gpt agent #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions shell/agents/AIShell.OpenAI.Agent/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ internal enum EndpointType
{
AzureOpenAI,
OpenAI,
CompatibleThirdParty,
}

public class GPT
Expand Down Expand Up @@ -56,9 +57,16 @@ public GPT(
bool noDeployment = string.IsNullOrEmpty(Deployment);
Type = noEndpoint && noDeployment
? EndpointType.OpenAI
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
: !noEndpoint && noDeployment
? EndpointType.CompatibleThirdParty
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: 'Deployment' key present but 'Endpoint' key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");

if (ModelInfo is null && Type is EndpointType.CompatibleThirdParty)
{
ModelInfo = ModelInfo.ThirdPartyModel;
}
}

/// <summary>
Expand Down Expand Up @@ -142,11 +150,18 @@ private void ShowEndpointInfo(IHost host)
new(label: " Model", m => m.ModelName),
},

EndpointType.OpenAI => new CustomElement<GPT>[]
{
EndpointType.OpenAI =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Model", m => m.ModelName),
},
],

EndpointType.CompatibleThirdParty =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Endpoint", m => m.Endpoint),
new(label: " Model", m => m.ModelName),
],

_ => throw new UnreachableException(),
};
Expand Down
5 changes: 5 additions & 0 deletions shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ internal class ModelInfo
private static readonly Dictionary<string, ModelInfo> s_modelMap;
private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;

// A rough estimate to cover all third-party models.
// - most popular models today support 32K+ context length;
// - use the gpt-4o encoding as an estimate for token count.
internal static readonly ModelInfo ThirdPartyModel = new(32_000, encoding: Gpt4oEncoding);

static ModelInfo()
{
// For reference, see https://platform.openai.com/docs/models and the "Counting tokens" section in
Expand Down
8 changes: 7 additions & 1 deletion shell/agents/AIShell.OpenAI.Agent/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ private void RefreshOpenAIClient()
return;
}

EndpointType type = _gptToUse.Type;
string userKey = Utils.ConvertFromSecureString(_gptToUse.Key);

if (_gptToUse.Type is EndpointType.AzureOpenAI)
if (type is EndpointType.AzureOpenAI)
{
// Create a client that targets Azure OpenAI service or Azure API Management service.
var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
Expand Down Expand Up @@ -152,6 +153,11 @@ private void RefreshOpenAIClient()
{
// Create a client that targets the non-Azure OpenAI service.
var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
if (type is EndpointType.CompatibleThirdParty)
{
clientOptions.Endpoint = new(_gptToUse.Endpoint);
}

var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions);
_client = aiClient.GetChatClient(_gptToUse.ModelName);
}
Expand Down
Loading