diff --git a/api_internal_test.go b/api_internal_test.go index 9651ad402..529e7c7c4 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) { az.OrgID = c.OrgID cli := NewClientWithConfig(az) - req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) + req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "") if err != nil { t.Errorf("Failed to create request: %v", err) } @@ -109,14 +109,16 @@ func TestRequestAuthHeader(t *testing.T) { func TestAzureFullURL(t *testing.T) { cases := []struct { - Name string - BaseURL string - Engine string - Expect string + Name string + BaseURL string + AzureModelMapper map[string]string + Model string + Expect string }{ { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", + nil, "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -125,6 +127,7 @@ func TestAzureFullURL(t *testing.T) { { "AzureBaseURLWithoutSlashOK", "https://httpbin.org", + nil, "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -134,10 +137,10 @@ func TestAzureFullURL(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) + az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL("/chat/completions", c.Model) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index d22daf98c..12c6ccc22 100644 --- a/audio.go +++ b/audio.go @@ -68,7 +68,7 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index c09861c8c..312ef8e20 100644 --- a/chat.go +++ b/chat.go @@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 9ed0bc70a..f4fda882a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -46,7 +46,7 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) + req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) if err != nil { return } diff --git a/client.go b/client.go index 0f8aa41ba..9579ba27b 100644 --- a/client.go +++ b/client.go @@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error { return nil } -func (c *Client) fullURL(suffix string) string { - // /openai/deployments/{engine}/chat/completions?api-version={api_version} +// fullURL returns full URL for request. +// args[0] is model name, if API type is Azure, model name is required to get deployment name. +func (c *Client) fullURL(suffix string, args ...any) string { + // /openai/deployments/{model}/chat/completions?api-version={api_version} if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") @@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string { if strings.Contains(suffix, "/models") { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } + azureDeploymentName := "UNKNOWN" + if len(args) > 0 { + model, ok := args[0].(string) + if ok { + azureDeploymentName = c.config.GetAzureDeploymentByModel(model) + } + } return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) + baseURL, azureAPIPrefix, azureDeploymentsPrefix, + azureDeploymentName, suffix, c.config.APIVersion, + ) } // c.config.APIType == APITypeOpenAI || c.config.APIType == "" @@ -120,8 +131,9 @@ func (c *Client) newStreamRequest( ctx context.Context, method string, urlSuffix string, - body any) (*http.Request, error) { - req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body) + body any, + model string) (*http.Request, error) { + req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body) if err != nil { return nil, err } diff --git a/completion.go b/completion.go index 5eec88c29..e3d1b85eb 100644 --- a/completion.go +++ b/completion.go @@ -155,7 +155,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/config.go b/config.go index c800df15c..fbcf377c0 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package openai import ( "net/http" + "regexp" ) const ( @@ -26,13 +27,12 @@ const AzureAPIKeyHeader = "api-key" type ClientConfig struct { authToken string - BaseURL string - OrgID string - APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD - Engine string // required when APIType is APITypeAzure or APITypeAzureAD - - HTTPClient *http.Client + BaseURL string + OrgID string + APIType APIType + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AzureModelMapperFunc func(model string) string // replace model to azure deployment name func + HTTPClient *http.Client EmptyMessagesLimit uint } @@ -50,14 +50,16 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { +func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { return ClientConfig{ authToken: apiKey, BaseURL: baseURL, OrgID: "", APIType: APITypeAzure, APIVersion: "2023-03-15-preview", - Engine: engine, + AzureModelMapperFunc: func(model string) string { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + }, HTTPClient: &http.Client{}, @@ -68,3 +70,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { func (ClientConfig) String() string { return "" } + +func (c ClientConfig) GetAzureDeploymentByModel(model string) string { + if c.AzureModelMapperFunc != nil { + return c.AzureModelMapperFunc(model) + } + + return model +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 000000000..488511b11 --- /dev/null +++ b/config_test.go @@ -0,0 +1,62 @@ +package openai_test + +import ( + "testing" + + . "github.com/sashabaranov/go-openai" +) + +func TestGetAzureDeploymentByModel(t *testing.T) { + cases := []struct { + Model string + AzureModelMapperFunc func(model string) string + Expect string + }{ + { + Model: "gpt-3.5-turbo", + Expect: "gpt-35-turbo", + }, + { + Model: "gpt-3.5-turbo-0301", + Expect: "gpt-35-turbo-0301", + }, + { + Model: "text-embedding-ada-002", + Expect: "text-embedding-ada-002", + }, + { + Model: "", + Expect: "", + }, + { + Model: "models", + Expect: "models", + }, + { + Model: "gpt-3.5-turbo", + Expect: "my-gpt35", + AzureModelMapperFunc: func(model string) string { + modelmapper := map[string]string{ + "gpt-3.5-turbo": "my-gpt35", + } + if val, ok := modelmapper[model]; ok { + return val + } + return model + }, + }, + } + + for _, c := range cases { + t.Run(c.Model, func(t *testing.T) { + conf := DefaultAzureConfig("", "https://test.openai.azure.com/") + if c.AzureModelMapperFunc != nil { + conf.AzureModelMapperFunc = c.AzureModelMapperFunc + } + actual := conf.GetAzureDeploymentByModel(c.Model) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + }) + } +} diff --git a/edits.go b/edits.go index 858a8e537..c2c8db794 100644 --- a/edits.go +++ b/edits.go @@ -2,6 +2,7 @@ package openai import ( "context" + "fmt" "net/http" ) @@ -31,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 2deaccc3a..7fb432ead 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) if err != nil { return } diff --git a/example_test.go b/example_test.go index da253806d..b5dfafea9 100644 --- a/example_test.go +++ b/example_test.go @@ -305,8 +305,7 @@ func Example_chatbot() { func ExampleDefaultAzureConfig() { azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint - azureModel := os.Getenv("AZURE_OPENAI_MODEL") // Your model deployment name - config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel) + config := openai.DefaultAzureConfig(azureKey, azureEndpoint) client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), diff --git a/models_test.go b/models_test.go index 70d6d756c..b017800b9 100644 --- a/models_test.go +++ b/models_test.go @@ -40,7 +40,7 @@ func TestAzureListModels(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine") + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config.BaseURL = ts.URL client := NewClientWithConfig(config) ctx := context.Background() diff --git a/moderation.go b/moderation.go index b386ddb95..ebd66afb9 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) if err != nil { return } diff --git a/stream.go b/stream.go index 95662db6d..cd435faea 100644 --- a/stream.go +++ b/stream.go @@ -35,7 +35,7 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) + req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) if err != nil { return }