Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ import (
var ErrInvalidMethod = errors.New("webhook only supports HTTP methods PUT or POST")

type clientConfiguration struct {
userAgent string
dialer net.Dialer // We use Dialer here instead of DialContext as our mqtt client doesn't support DialContext.
customDialer bool
httpClientConfig HTTPClientConfig
userAgent string
dialer net.Dialer // We use Dialer here instead of DialContext as our mqtt client doesn't support DialContext.
customDialer bool
}

// defaultDialTimeout is the default timeout for the dialer, 30 seconds to match http.DefaultTransport.
Expand All @@ -39,7 +38,7 @@ type Client struct {
oauth2TokenSource oauth2.TokenSource
}

func NewClient(opts ...ClientOption) (*Client, error) {
func NewClient(httpClientConfig *HTTPClientConfig, opts ...ClientOption) (*Client, error) {
cfg := clientConfiguration{
userAgent: "Grafana",
dialer: net.Dialer{},
Expand All @@ -58,13 +57,13 @@ func NewClient(opts ...ClientOption) (*Client, error) {
cfg: cfg,
}

if cfg.httpClientConfig.OAuth2 != nil {
if err := ValidateOAuth2Config(cfg.httpClientConfig.OAuth2); err != nil {
if httpClientConfig != nil && httpClientConfig.OAuth2 != nil {
if err := ValidateOAuth2Config(httpClientConfig.OAuth2); err != nil {
return nil, fmt.Errorf("invalid OAuth2 configuration: %w", err)
}
// If the user has provided an OAuth2 config, we need to prepare the OAuth2 token source. This needs to
// be stored outside of the request so that the token expiration/re-use will work as expected.
tokenSource, err := NewOAuth2TokenSource(cfg)
tokenSource, err := NewOAuth2TokenSource(*httpClientConfig.OAuth2, cfg)
if err != nil {
return nil, err
}
Expand All @@ -89,14 +88,6 @@ func WithDialer(dialer net.Dialer) ClientOption {
}
}

func WithHTTPClientConfig(config *HTTPClientConfig) ClientOption {
return func(c *clientConfiguration) {
if config != nil {
c.httpClientConfig = *config
}
}
}

func ToHTTPClientOption(option ...ClientOption) []commoncfg.HTTPClientOption {
cfg := clientConfiguration{}
for _, opt := range option {
Expand Down
38 changes: 18 additions & 20 deletions http/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,28 @@ import (

func TestClient(t *testing.T) {
t.Run("NewClient", func(t *testing.T) {
client, err := NewClient()
client, err := NewClient(nil)
require.NoError(t, err)
require.NotNil(t, client)
})

t.Run("WithUserAgent", func(t *testing.T) {
client, err := NewClient(WithUserAgent("TEST"))
client, err := NewClient(nil, WithUserAgent("TEST"))
require.NoError(t, err)
require.Equal(t, "TEST", client.cfg.userAgent)
})

t.Run("WithDialer with timeout", func(t *testing.T) {
dialer := net.Dialer{Timeout: 5 * time.Second}
client, err := NewClient(WithDialer(dialer))
client, err := NewClient(nil, WithDialer(dialer))
require.NoError(t, err)
require.Equal(t, dialer, client.cfg.dialer)
})

t.Run("WithDialer missing timeout should use default", func(t *testing.T) {
// Mostly defensive to ensure that some timeout is set.
dialer := net.Dialer{LocalAddr: &net.TCPAddr{IP: net.ParseIP("::")}}
client, err := NewClient(WithDialer(dialer))
client, err := NewClient(nil, WithDialer(dialer))
require.NoError(t, err)

expectedDialer := dialer
Expand All @@ -58,12 +58,12 @@ func TestClient(t *testing.T) {
ClientSecret: "test-client-secret",
TokenURL: "https://localhost:8080/oauth2/token",
}
client, err := NewClient(WithHTTPClientConfig(&HTTPClientConfig{
client, err := NewClient(&HTTPClientConfig{
OAuth2: oauth2Config,
}))
})
require.NoError(t, err)

require.Equal(t, oauth2Config, client.cfg.httpClientConfig.OAuth2)
require.NotNil(t, client.oauth2TokenSource)
})

t.Run("WithOAuth2 invalid TLS", func(t *testing.T) {
Expand All @@ -75,9 +75,9 @@ func TestClient(t *testing.T) {
CACertificate: "invalid-ca-cert",
},
}
_, err := NewClient(WithHTTPClientConfig(&HTTPClientConfig{
_, err := NewClient(&HTTPClientConfig{
OAuth2: oauth2Config,
}))
})
require.ErrorIs(t, err, ErrOAuth2TLSConfigInvalid)
})
}
Expand All @@ -92,7 +92,7 @@ func TestSendWebhook(t *testing.T) {
got = r
w.WriteHeader(http.StatusOK)
}))
s, err := NewClient(WithUserAgent("TEST"))
s, err := NewClient(nil, WithUserAgent("TEST"))
require.NoError(t, err)

// The method should be either POST or PUT.
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSendWebhookHMAC(t *testing.T) {
server := initServer(httptest.NewServer)
defer server.Close()

client, err := NewClient()
client, err := NewClient(nil)
require.NoError(t, err)
webhook := &receivers.SendWebhookSettings{
URL: server.URL,
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestSendWebhookHMAC(t *testing.T) {
cfg, err := tlsConfig.ToCryptoTLSConfig()
require.NoError(t, err)

client, err := NewClient()
client, err := NewClient(nil)
require.NoError(t, err)
webhook := &receivers.SendWebhookSettings{
URL: server.URL,
Expand Down Expand Up @@ -419,8 +419,8 @@ func TestSendWebhookOAuth2(t *testing.T) {
oauth2Config: OAuth2Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ProxyConfig: ProxyConfig{
ProxyURL: mustURL("http://<server>.com"), // This will be replaced with the test server URL.
ProxyConfig: &ProxyConfig{
ProxyURL: MustURL("http://<server>.com"), // This will be replaced with the test server URL.
},
},
oauth2Response: oauth2Response{
Expand Down Expand Up @@ -484,17 +484,15 @@ func TestSendWebhookOAuth2(t *testing.T) {
oauthConfig := tc.oauth2Config
oauthConfig.TokenURL = tokenURL

if oauthConfig.ProxyConfig.ProxyURL.URL != nil && oauthConfig.ProxyConfig.ProxyURL.String() != "" {
oauthConfig.ProxyConfig.ProxyURL = mustURL(proxyServer.URL)
if oauthConfig.ProxyConfig != nil {
oauthConfig.ProxyConfig.ProxyURL = MustURL(proxyServer.URL)
}
expectedProxyRequestCnt := 0
if tc.expProxyRequests {
expectedProxyRequestCnt = 1
}

client, err := NewClient(append(tc.otherClientOpts, WithHTTPClientConfig(&HTTPClientConfig{
OAuth2: &oauthConfig,
}))...)
client, err := NewClient(&HTTPClientConfig{OAuth2: &oauthConfig}, tc.otherClientOpts...)
if tc.expClientError != nil {
assert.ErrorIs(t, err, tc.expClientError, "expected client creation error to match")
return
Expand Down Expand Up @@ -549,5 +547,5 @@ func TestToHTTPClientOption(t *testing.T) {
// Verify number of fields using reflection
tp := reflect.TypeOf(clientConfiguration{})
// You need to increase the number of fields covered in this test, if you add a new field to the configuration struct.
require.Equalf(t, 4, tp.NumField(), "Not all fields are converted to HTTPClientOption, which means that the configuration will not be supported in upstream integrations")
require.Equalf(t, 3, tp.NumField(), "Not all fields are converted to HTTPClientOption, which means that the configuration will not be supported in upstream integrations")
}
13 changes: 2 additions & 11 deletions http/config_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package http

import (
"net/url"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -61,7 +60,7 @@ func TestProxyConfigValidation(t *testing.T) {
{
name: "valid proxy config with URL",
cfg: ProxyConfig{
ProxyURL: mustURL("http://proxy.example.com:8080"),
ProxyURL: MustURL("http://proxy.example.com:8080"),
},
wantErr: false,
},
Expand All @@ -73,7 +72,7 @@ func TestProxyConfigValidation(t *testing.T) {
{
name: "invalid proxy URL and environment",
cfg: ProxyConfig{
ProxyURL: mustURL("http://proxy.example.com:8080"),
ProxyURL: MustURL("http://proxy.example.com:8080"),
ProxyFromEnvironment: true,
},
wantErr: true,
Expand Down Expand Up @@ -113,11 +112,3 @@ func TestProxyConfigValidation(t *testing.T) {
})
}
}

func mustURL(u string) URL {
res, err := url.Parse(u)
if err != nil {
panic(err)
}
return URL{URL: res}
}
15 changes: 6 additions & 9 deletions http/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type OAuth2Config struct {
Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"`
EndpointParams map[string]string `json:"endpoint_params,omitempty" yaml:"endpoint_params,omitempty"`
TLSConfig *receivers.TLSConfig `json:"tls_config,omitempty" yaml:"tls_config,omitempty"`
ProxyConfig ProxyConfig `json:"proxy_config" yaml:"proxy_config"`
ProxyConfig *ProxyConfig `json:"proxy_config,omitempty" yaml:"proxy_config,omitempty"`
}

func ValidateOAuth2Config(config *OAuth2Config) error {
Expand All @@ -57,8 +57,10 @@ func ValidateOAuth2Config(config *OAuth2Config) error {
}
}

if err := ValidateProxyConfig(config.ProxyConfig); err != nil {
return fmt.Errorf("%w: %w", ErrInvalidProxyConfig, err)
if config.ProxyConfig != nil {
if err := ValidateProxyConfig(*config.ProxyConfig); err != nil {
return fmt.Errorf("%w: %w", ErrInvalidProxyConfig, err)
}
}

return nil
Expand All @@ -77,12 +79,7 @@ func NewOAuth2RoundTripper(tokenSource oauth2.TokenSource, next http.RoundTrippe
}
}

func NewOAuth2TokenSource(clientConfig clientConfiguration) (oauth2.TokenSource, error) {
config := clientConfig.httpClientConfig.OAuth2
if config == nil {
// This should never happen, but we add this check defensively.
return nil, fmt.Errorf("OAuth2 configuration is required")
}
func NewOAuth2TokenSource(config OAuth2Config, clientConfig clientConfiguration) (oauth2.TokenSource, error) {
credconfig := &clientcredentials.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Expand Down
2 changes: 1 addition & 1 deletion http/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestValidateOAuth2Config(t *testing.T) {
ClientID: "client-id",
ClientSecret: "client-secret",
TokenURL: "https://example.com/token",
ProxyConfig: ProxyConfig{
ProxyConfig: &ProxyConfig{
NoProxy: "localhost",
},
},
Expand Down
10 changes: 10 additions & 0 deletions http/testing.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package http

import "net/url"

const TestCACert = `-----BEGIN CERTIFICATE-----
MIGrMF+gAwIBAgIBATAFBgMrZXAwADAeFw0yNDExMTYxMDI4MzNaFw0yNTExMTYx
MDI4MzNaMAAwKjAFBgMrZXADIQCf30GvRnHbs9gukA3DLXDK6W5JVgYw6mERU/60
Expand All @@ -22,3 +24,11 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

func MustURL(u string) URL {
res, err := url.Parse(u)
if err != nil {
panic(err)
}
return URL{URL: res}
}
Loading