Skip to content

Commit

Permalink
feat (config): add support for a http.RoundTripper (#137)
Browse files Browse the repository at this point in the history
Add support for specifying an optional http.RoundTripper
for a provider config.  If specified the http
client will use the RoundTripper when making
requests to the provider.
  • Loading branch information
jimlambrt authored Aug 1, 2024
1 parent 36b85f9 commit 3dae6e2
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 9 deletions.
40 changes: 38 additions & 2 deletions oidc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"hash"
"hash/fnv"
"net/http"
"net/url"
"reflect"
"runtime"
Expand Down Expand Up @@ -89,9 +90,15 @@ type Config struct {

// ProviderCA is an optional CA certs (PEM encoded) to use when sending
// requests to the provider. If you have a list of *x509.Certificates, then
// see EncodeCertificates(...) to PEM encode them.
// see EncodeCertificates(...) to PEM encode them. Note: specifying both
// ProviderCA and RoundTripper is an error.
ProviderCA string

// RoundTripper is an optional http.RoundTripper to use when sending requests
// to the provider. Note: specifying both ProviderCA and RoundTripper is an
// error.
RoundTripper http.RoundTripper

// NowFunc is a time func that returns the current time.
NowFunc func() time.Time `json:"-"`

Expand All @@ -118,6 +125,7 @@ func NewConfig(issuer string, clientID string, clientSecret ClientSecret, suppor
SupportedSigningAlgs: supported,
Scopes: opts.withScopes,
ProviderCA: opts.withProviderCA,
RoundTripper: opts.withRoundTripper,
Audiences: opts.withAudiences,
NowFunc: opts.withNowFunc,
AllowedRedirectURLs: allowedRedirectURLs,
Expand Down Expand Up @@ -168,6 +176,16 @@ func (c *Config) Hash() (uint64, error) {
args = append(args, audiences...)
args = append(args, redirects...)

if c.RoundTripper != nil {
v := reflect.ValueOf(c.RoundTripper)
switch {
case v.CanAddr():
args = append(args, v.Addr().String())
default:
args = append(args, v.String())
}
}

if c.ProviderConfig != nil {
args = append(
args,
Expand Down Expand Up @@ -269,6 +287,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("%s: %w", op, ErrInvalidCACert)
}
}
if c.ProviderCA != "" && c.RoundTripper != nil {
return fmt.Errorf("%s: you cannot specify both a ProviderCA and RoundTripper: %w", op, ErrInvalidParameter)
}

if c.ProviderConfig != nil {
switch {
Expand Down Expand Up @@ -300,6 +321,7 @@ type configOptions struct {
withProviderCA string
withNowFunc func() time.Time
withProviderConfig *ProviderConfig
withRoundTripper http.RoundTripper
}

// configDefaults is a handy way to get the defaults at runtime and
Expand All @@ -319,12 +341,14 @@ func getConfigOpts(opt ...Option) configOptions {
}

// WithProviderCA provides optional CA certs (PEM encoded) for the provider's
// config. These certs will can be used when making http requests to the
// config. These certs will be used when making http requests to the
// provider.
//
// Valid for: Config
//
// See EncodeCertificates(...) to PEM encode a number of certs.
//
// Note: specifying both WithProviderCA and WithRoundTripper is a error.
func WithProviderCA(cert string) Option {
return func(o interface{}) {
if o, ok := o.(*configOptions); ok {
Expand All @@ -333,6 +357,18 @@ func WithProviderCA(cert string) Option {
}
}

// WithRoundTripper provides and optional RoundTripper for the provider's
// config. This RoundTripper will be used when making http requests to the
// provider. Note: specifying both WithProviderCA and WithRoundTripper is a
// error.
func WithRoundTripper(rt http.RoundTripper) Option {
return func(o interface{}) {
if o, ok := o.(*configOptions); ok {
o.withRoundTripper = rt
}
}
}

// EncodeCertificates will encode a number of x509 certificates to PEM. It will
// help encode certs for use with the WithProviderCA(...) option.
func EncodeCertificates(certs ...*x509.Certificate) (string, error) {
Expand Down
141 changes: 140 additions & 1 deletion oidc/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"errors"
"fmt"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -44,6 +45,8 @@ func TestNewConfig(t *testing.T) {
return time.Now().Add(-1 * time.Minute)
}

testRt := newTestRoundTripper(t)

type args struct {
issuer string
clientID string
Expand All @@ -61,7 +64,7 @@ func TestNewConfig(t *testing.T) {
wantErrContains string
}{
{
name: "valid-with-all-valid-opts",
name: "valid-with-all-valid-opts-except-with-round-tripper",
args: args{
issuer: "http://your_issuer/",
clientID: "your_client_id",
Expand Down Expand Up @@ -103,6 +106,49 @@ func TestNewConfig(t *testing.T) {
},
},
},
{
name: "with-round-tripper",
args: args{
issuer: "http://your_issuer/",
clientID: "your_client_id",
clientSecret: "your_client_secret",
supported: []Alg{RS512},
allowedRedirectURLs: []string{"http://your_redirect_url", "http://redirect_url_two", "http://redirect_url_three"},
opt: []Option{
WithAudiences("your_aud1", "your_aud2"),
WithScopes("email", "profile"),
WithRoundTripper(testRt),
WithNow(testNow),
WithProviderConfig(&ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
}),
},
},
want: &Config{
Issuer: "http://your_issuer/",
ClientID: "your_client_id",
ClientSecret: "your_client_secret",
SupportedSigningAlgs: []Alg{RS512},
Audiences: []string{"your_aud1", "your_aud2"},
Scopes: []string{oidc.ScopeOpenID, "email", "profile"},
RoundTripper: testRt,
NowFunc: testNow,
AllowedRedirectURLs: []string{
"http://your_redirect_url",
"http://redirect_url_two",
"http://redirect_url_three",
},
ProviderConfig: &ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
},
},
},
{
name: "missing-provider-config-auth-url",
args: args{
Expand Down Expand Up @@ -282,6 +328,22 @@ func TestNewConfig(t *testing.T) {
wantErr: true,
wantIsErr: ErrInvalidCACert,
},
{
name: "invalid-both-cert-and-round-tripper",
args: args{
issuer: "http://your_issuer/",
clientID: "your_client_id",
clientSecret: "your_client_secret",
supported: []Alg{RS512},
allowedRedirectURLs: []string{"http://your_redirect_url"},
opt: []Option{
WithProviderCA(testCaPem),
WithRoundTripper(testRt),
},
},
wantErr: true,
wantIsErr: ErrInvalidParameter,
},
{
name: "invalid-alg",
args: args{
Expand Down Expand Up @@ -430,6 +492,7 @@ func TestConfig_Hash(t *testing.T) {
require.NoError(t, err)
return c
}
testRt := newTestRoundTripper(t)
tests := []struct {
name string
c1 *Config
Expand Down Expand Up @@ -473,6 +536,42 @@ func TestConfig_Hash(t *testing.T) {
),
wantEqual: true,
},
{
name: "equal-with-round-tripper",
c1: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.alice.com/callback", "www.bob.com/callback"},
WithScopes("email", "profile"),
WithAudiences("alice.com", "bob.com"),
WithRoundTripper(testRt),
WithNow(time.Now),
WithProviderConfig(&ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
}),
),
c2: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.bob.com/callback", "www.alice.com/callback"},
WithScopes("profile", "email"),
WithAudiences("bob.com", "alice.com"),
WithRoundTripper(testRt),
WithNow(time.Now),
WithProviderConfig(&ProviderConfig{
AuthURL: "https://auth-endpoint",
JWKSURL: "https://jwks-endpoint",
TokenURL: "https://token-endpoint",
UserInfoURL: "https://userinfo-endpoint",
}),
),
wantEqual: true,
},
{
name: "diff-issuer",
c1: newCfg(
Expand Down Expand Up @@ -664,6 +763,29 @@ func TestConfig_Hash(t *testing.T) {
),
wantEqual: false,
},
{
name: "diff-round-trippers",
c1: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.alice.com/callback"},
WithScopes("email", "profile"),
WithAudiences("alice.com", "bob.com"),
WithRoundTripper(newTestRoundTripper(t)),
WithNow(time.Now),
),
c2: newCfg(
"https://www.alice.com",
"client-id", "client-secret",
[]Alg{RS256},
[]string{"www.alice.com/callback"},
WithScopes("email", "profile"),
WithAudiences("alice.com", "bob.com"),
WithNow(time.Now),
),
wantEqual: false,
},
{
name: "diff-now-func",
c1: newCfg(
Expand Down Expand Up @@ -855,3 +977,20 @@ func TestConfig_Hash(t *testing.T) {
})
}
}

type testRoundTripper struct {
transport http.RoundTripper
called int
}

func newTestRoundTripper(t *testing.T) *testRoundTripper {
t.Helper()
return &testRoundTripper{
transport: http.DefaultTransport,
}
}

func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.called++
return rt.transport.RoundTrip(req)
}
4 changes: 2 additions & 2 deletions oidc/docs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func ExampleNewConfig() {
fmt.Println(pc)

// Output:
// &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] <nil> <nil>}
// &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] <nil> <nil> <nil>}
}

func ExampleWithProviderConfig() {
Expand All @@ -120,7 +120,7 @@ func ExampleWithProviderConfig() {
fmt.Println(string(val))

// Output:
// {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}}
// {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","RoundTripper":null,"ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}}
}

func ExampleNewProvider() {
Expand Down
13 changes: 9 additions & 4 deletions oidc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,17 +635,22 @@ func (p *Provider) HTTPClient() (*http.Client, error) {
// to the same host. On the downside, this transport can leak file
// descriptors over time, so we'll be sure to call
// client.CloseIdleConnections() in the Provider.Done() to stave that off.
tr := cleanhttp.DefaultPooledTransport()
var tr http.RoundTripper

if p.config.ProviderCA != "" {
switch {
case p.config.RoundTripper != nil && p.config.ProviderCA != "":
return nil, fmt.Errorf("%s: you cannot specify config for both a ProviderCA and RoundTripper: %w", op, ErrInvalidParameter)
case p.config.ProviderCA != "":
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(p.config.ProviderCA)); !ok {
return nil, fmt.Errorf("%s: %w", op, ErrInvalidCACert)
}

tr.TLSClientConfig = &tls.Config{
tr = cleanhttp.DefaultPooledTransport()
tr.(*http.Transport).TLSClientConfig = &tls.Config{
RootCAs: certPool,
}
case p.config.RoundTripper != nil:
tr = p.config.RoundTripper
}

c := &http.Client{
Expand Down
25 changes: 25 additions & 0 deletions oidc/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,31 @@ func TestHTTPClient(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, c.Transport, p.client.Transport)
})
t.Run("check-transport-with-round-tripper", func(t *testing.T) {
testRt := newTestRoundTripper(t)
p := &Provider{
config: &Config{
RoundTripper: testRt,
},
}
c, err := p.HTTPClient()
require.NoError(t, err)
assert.Equal(t, c.Transport, p.client.Transport)
})
t.Run("err-both-ca-and-round-trippe", func(t *testing.T) {
_, testCaPem := TestGenerateCA(t, []string{"localhost"})

p := &Provider{
config: &Config{
ProviderCA: testCaPem,
RoundTripper: newTestRoundTripper(t),
},
}
_, err := p.HTTPClient()
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidParameter)
assert.ErrorContains(t, err, "you cannot specify config for both a ProviderCA and RoundTripper")
})
}

func TestProvider_UserInfo(t *testing.T) {
Expand Down

0 comments on commit 3dae6e2

Please sign in to comment.