Skip to content

Commit

Permalink
feat: support openai organization Id (k8sgpt-ai#1133)
Browse files Browse the repository at this point in the history
* feat: add organization flag

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

* feat: add orgId on openai backend

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>

---------

Signed-off-by: JuHyung-Son <sonju0427@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
  • Loading branch information
JuHyung-Son and AlexsJones authored Jun 14, 2024
1 parent c834c09 commit 4867d39
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 44 deletions.
3 changes: 3 additions & 0 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ var addCmd = &cobra.Command{
TopP: topP,
TopK: topK,
MaxTokens: maxTokens,
OrganizationId: organizationId,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -176,4 +177,6 @@ func init() {
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
//add flag for OCI Compartment ID
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
// add flag for openai organization
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
}
1 change: 1 addition & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
topP float32
topK int32
maxTokens int
organizationId string
)

var configAI ai.AIConfiguration
Expand Down
80 changes: 43 additions & 37 deletions cmd/auth/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ var updateCmd = &cobra.Command{
Use: "update",
Short: "Update a backend provider",
Long: "The command to update an AI backend provider",
Args: cobra.ExactArgs(1),
// Args: cobra.ExactArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
backend, _ := cmd.Flags().GetString("backend")
if strings.ToLower(backend) == "azureopenai" {
_ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl")
}
organizationId, _ := cmd.Flags().GetString("organizationId")
if strings.ToLower(backend) != "azureopenai" && strings.ToLower(backend) != "openai" {
if organizationId != "" {
color.Red("Error: organizationId must be empty for backends other than azureopenai or openai.")
os.Exit(1)
}
}
},
Run: func(cmd *cobra.Command, args []string) {

Expand All @@ -43,50 +50,47 @@ var updateCmd = &cobra.Command{
os.Exit(1)
}

inputBackends := strings.Split(args[0], ",")
backend, _ := cmd.Flags().GetString("backend")

if len(inputBackends) == 0 {
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}

for _, b := range inputBackends {
foundBackend := false
for i, provider := range configAI.Providers {
if b == provider.Name {
foundBackend = true
if backend != "" {
configAI.Providers[i].Name = backend
color.Blue("Backend name updated successfully")
}
if model != "" {
configAI.Providers[i].Model = model
color.Blue("Model updated successfully")
}
if password != "" {
configAI.Providers[i].Password = password
color.Blue("Password updated successfully")
}
if baseURL != "" {
configAI.Providers[i].BaseURL = baseURL
color.Blue("Base URL updated successfully")
}
if engine != "" {
configAI.Providers[i].Engine = engine
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
foundBackend := false
for i, provider := range configAI.Providers {
if backend == provider.Name {
foundBackend = true
if backend != "" {
configAI.Providers[i].Name = backend
color.Blue("Backend name updated successfully")
}
if model != "" {
configAI.Providers[i].Model = model
color.Blue("Model updated successfully")
}
if password != "" {
configAI.Providers[i].Password = password
color.Blue("Password updated successfully")
}
if baseURL != "" {
configAI.Providers[i].BaseURL = baseURL
color.Blue("Base URL updated successfully")
}
if engine != "" {
configAI.Providers[i].Engine = engine
}
if organizationId != "" {
configAI.Providers[i].OrganizationId = organizationId
color.Blue("Organization Id updated successfully")
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", backend)
}
if !foundBackend {
color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0])
os.Exit(1)
}

}
if !foundBackend {
color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0])
os.Exit(1)
}

viper.Set("ai", configAI)
Expand All @@ -110,4 +114,6 @@ func init() {
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
// update flag for organizationId
updateCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "Update OpenAI or Azure organization Id")
}
12 changes: 9 additions & 3 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ const azureAIClientName = "azureopenai"
type AzureAIClient struct {
nopCloser

client *openai.Client
model string
temperature float32
client *openai.Client
model string
temperature float32
organizationId string
}

func (c *AzureAIClient) Configure(config IAIConfig) error {
Expand All @@ -25,6 +26,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
engine := config.GetEngine()
proxyEndpoint := config.GetProxyEndpoint()
defaultConfig := openai.DefaultAzureConfig(token, baseURL)
orgId := config.GetOrganizationId()

defaultConfig.AzureModelMapperFunc = func(model string) string {
// If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function
Expand All @@ -48,6 +50,10 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
Transport: transport,
}
}
if orgId != "" {
defaultConfig.OrgID = orgId
}

client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating Azure OpenAI client")
Expand Down
6 changes: 6 additions & 0 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type IAIConfig interface {
GetMaxTokens() int
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
}

func NewClient(provider string) IAI {
Expand Down Expand Up @@ -111,6 +112,7 @@ type AIProvider struct {
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
}

func (p *AIProvider) GetBaseURL() string {
Expand Down Expand Up @@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string {
return p.CompartmentId
}

func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}

func NeedPassword(backend string) bool {
Expand Down
14 changes: 10 additions & 4 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ const openAIClientName = "openai"
type OpenAIClient struct {
nopCloser

client *openai.Client
model string
temperature float32
topP float32
client *openai.Client
model string
temperature float32
topP float32
organizationId string
}

const (
Expand All @@ -43,6 +44,7 @@ const (
func (c *OpenAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
orgId := config.GetOrganizationId()
proxyEndpoint := config.GetProxyEndpoint()

baseURL := config.GetBaseURL()
Expand All @@ -64,6 +66,10 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
}
}

if orgId != "" {
defaultConfig.OrgID = orgId
}

client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating OpenAI client")
Expand Down

0 comments on commit 4867d39

Please sign in to comment.