diff --git a/README.md b/README.md index 38cd565e87..af1181e824 100644 --- a/README.md +++ b/README.md @@ -299,6 +299,12 @@ _Analysis with serve mode_ ``` grpcurl -plaintext -d '{"namespace": "k8sgpt", "explain": false}' localhost:8080 schema.v1.ServerService/Analyze ``` + +_Analysis with custom headers_ + +``` +k8sgpt analyze --explain --custom-headers CustomHeaderKey:CustomHeaderValue +``` ## LLM AI Backends diff --git a/cmd/analyze/analyze.go b/cmd/analyze/analyze.go index a9b3cbe3a2..69c5dbb2e5 100644 --- a/cmd/analyze/analyze.go +++ b/cmd/analyze/analyze.go @@ -38,6 +38,7 @@ var ( withDoc bool interactiveMode bool customAnalysis bool + customHeaders []string ) // AnalyzeCmd represents the problems command @@ -59,6 +60,7 @@ var AnalyzeCmd = &cobra.Command{ maxConcurrency, withDoc, interactiveMode, + customHeaders, ) if err != nil { @@ -138,5 +140,6 @@ func init() { AnalyzeCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Enable interactive mode that allows further conversation with LLM about the problem. Works only with --explain flag") // custom analysis flag AnalyzeCmd.Flags().BoolVarP(&customAnalysis, "custom-analysis", "z", false, "Enable custom analyzers") - + // add custom headers flag + AnalyzeCmd.Flags().StringSliceVarP(&customHeaders, "custom-headers", "r", []string{}, "Custom Headers, : (e.g CustomHeaderKey:CustomHeaderValue AnotherHeader:AnotherValue)") } diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index d158882b13..38c8500346 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -15,6 +15,7 @@ package ai import ( "context" + "net/http" ) var ( @@ -83,6 +84,7 @@ type IAIConfig interface { GetProviderId() string GetCompartmentId() string GetOrganizationId() string + GetCustomHeaders() []http.Header } func NewClient(provider string) IAI { @@ -101,22 +103,23 @@ type AIConfiguration struct { } type AIProvider struct { - Name string `mapstructure:"name"` - Model string `mapstructure:"model"` - Password string `mapstructure:"password" yaml:"password,omitempty"` - BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"` - ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"` - ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"` - EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"` - Engine string `mapstructure:"engine" yaml:"engine,omitempty"` - Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"` - ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"` - ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"` - CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"` - TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` - TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` - MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` - OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` + Name string `mapstructure:"name"` + Model string `mapstructure:"model"` + Password string `mapstructure:"password" yaml:"password,omitempty"` + BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"` + ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"` + ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"` + EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"` + Engine string `mapstructure:"engine" yaml:"engine,omitempty"` + Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"` + ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"` + ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"` + CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"` + TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` + TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` + MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` + OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` + CustomHeaders []http.Header `mapstructure:"customHeaders"` } func (p *AIProvider) GetBaseURL() string { @@ -174,6 +177,10 @@ func (p *AIProvider) GetOrganizationId() string { return p.OrganizationId } +func (p *AIProvider) GetCustomHeaders() []http.Header { + return p.CustomHeaders +} + var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"} func NeedPassword(backend string) bool { diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index 8ed0f0ced8..727d39ec23 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -52,24 +52,27 @@ func (c *OpenAIClient) Configure(config IAIConfig) error { defaultConfig.BaseURL = baseURL } + transport := &http.Transport{} if proxyEndpoint != "" { proxyUrl, err := url.Parse(proxyEndpoint) if err != nil { return err } - transport := &http.Transport{ - Proxy: http.ProxyURL(proxyUrl), - } - - defaultConfig.HTTPClient = &http.Client{ - Transport: transport, - } + transport.Proxy = http.ProxyURL(proxyUrl) } if orgId != "" { defaultConfig.OrgID = orgId } + customHeaders := config.GetCustomHeaders() + defaultConfig.HTTPClient = &http.Client{ + Transport: &OpenAIHeaderTransport{ + Origin: transport, + Headers: customHeaders, + }, + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating OpenAI client") @@ -106,3 +109,25 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string func (c *OpenAIClient) GetName() string { return openAIClientName } + +// OpenAIHeaderTransport is an http.RoundTripper that adds the given headers to each request. +type OpenAIHeaderTransport struct { + Origin http.RoundTripper + Headers []http.Header +} + +// RoundTrip implements the http.RoundTripper interface. +func (t *OpenAIHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid modifying the original request + clonedReq := req.Clone(req.Context()) + for _, header := range t.Headers { + for key, values := range header { + // Possible values per header: RFC 2616 + for _, value := range values { + clonedReq.Header.Add(key, value) + } + } + } + + return t.Origin.RoundTrip(clonedReq) +} diff --git a/pkg/ai/openai_header_transport_test.go b/pkg/ai/openai_header_transport_test.go new file mode 100644 index 0000000000..9d43f463f5 --- /dev/null +++ b/pkg/ai/openai_header_transport_test.go @@ -0,0 +1,106 @@ +package ai + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Mock configuration +type mockConfig struct { + baseURL string +} + +func (m *mockConfig) GetPassword() string { + return "" +} + +func (m *mockConfig) GetOrganizationId() string { + return "" +} + +func (m *mockConfig) GetProxyEndpoint() string { + return "" +} + +func (m *mockConfig) GetBaseURL() string { + return m.baseURL +} + +func (m *mockConfig) GetCustomHeaders() []http.Header { + return []http.Header{ + {"X-Custom-Header-1": []string{"Value1"}}, + {"X-Custom-Header-2": []string{"Value2"}}, + {"X-Custom-Header-2": []string{"Value3"}}, // Testing multiple values for the same header + } +} + +func (m *mockConfig) GetModel() string { + return "" +} + +func (m *mockConfig) GetTemperature() float32 { + return 0.0 +} + +func (m *mockConfig) GetTopP() float32 { + return 0.0 +} +func (m *mockConfig) GetCompartmentId() string { + return "" +} + +func (m *mockConfig) GetTopK() int32 { + return 0.0 +} + +func (m *mockConfig) GetMaxTokens() int { + return 0 +} + +func (m *mockConfig) GetEndpointName() string { + return "" +} +func (m *mockConfig) GetEngine() string { + return "" +} + +func (m *mockConfig) GetProviderId() string { + return "" +} + +func (m *mockConfig) GetProviderRegion() string { + return "" +} + +func TestOpenAIClient_CustomHeaders(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Value1", r.Header.Get("X-Custom-Header-1")) + assert.ElementsMatch(t, []string{"Value2", "Value3"}, r.Header["X-Custom-Header-2"]) + w.WriteHeader(http.StatusOK) + // Mock response for openai completion + mockResponse := `{"choices": [{"message": {"content": "test"}}]}` + n, err := w.Write([]byte(mockResponse)) + if err != nil { + t.Fatalf("error writing response: %v", err) + } + if n != len(mockResponse) { + t.Fatalf("expected to write %d bytes but wrote %d bytes", len(mockResponse), n) + } + })) + defer server.Close() + + config := &mockConfig{baseURL: server.URL} + + client := &OpenAIClient{} + err := client.Configure(config) + assert.NoError(t, err) + + // Make a completion request to trigger the headers + ctx := context.Background() + _, err = client.GetCompletion(ctx, "foo prompt") + assert.NoError(t, err) +} diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go index 17771233e8..5ee36cbec0 100644 --- a/pkg/analysis/analysis.go +++ b/pkg/analysis/analysis.go @@ -79,6 +79,7 @@ func NewAnalysis( maxConcurrency int, withDoc bool, interactiveMode bool, + httpHeaders []string, ) (*Analysis, error) { // Get kubernetes client from viper. kubecontext := viper.GetString("kubecontext") @@ -146,6 +147,8 @@ func NewAnalysis( } aiClient := ai.NewClient(aiProvider.Name) + customHeaders := util.NewHeaders(httpHeaders) + aiProvider.CustomHeaders = customHeaders if err := aiClient.Configure(&aiProvider); err != nil { return nil, err } diff --git a/pkg/server/analyze.go b/pkg/server/analyze.go index edd4dafaa0..26b1189261 100644 --- a/pkg/server/analyze.go +++ b/pkg/server/analyze.go @@ -28,8 +28,9 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) ( i.Nocache, i.Explain, int(i.MaxConcurrency), - false, // Kubernetes Doc disabled in server mode - false, // Interactive mode disabled in server mode + false, // Kubernetes Doc disabled in server mode + false, // Interactive mode disabled in server mode + []string{}, //TODO: add custom http headers in server mode ) config.Context = ctx // Replace context for correct timeouts. if err != nil { diff --git a/pkg/util/util.go b/pkg/util/util.go index 67852c6c3e..4babf28a57 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -21,6 +21,7 @@ import ( "encoding/hex" "errors" "fmt" + "net/http" "os" "regexp" "strings" @@ -261,3 +262,36 @@ func FetchLatestEvent(ctx context.Context, kubernetesClient *kubernetes.Client, } return latestEvent, nil } + +// NewHeaders parses a slice of strings in the format "key:value" into []http.Header +// It handles headers with the same key by appending values +func NewHeaders(customHeaders []string) []http.Header { + headers := make(map[string][]string) + + for _, header := range customHeaders { + vals := strings.SplitN(header, ":", 2) + if len(vals) != 2 { + //TODO: Handle error instead of ignoring it + continue + } + key := strings.TrimSpace(vals[0]) + value := strings.TrimSpace(vals[1]) + + if _, ok := headers[key]; !ok { + headers[key] = []string{} + } + headers[key] = append(headers[key], value) + } + + // Convert map to []http.Header format + var result []http.Header + for key, values := range headers { + header := make(http.Header) + for _, value := range values { + header.Add(key, value) + } + result = append(result, header) + } + + return result +}