From faab1205f1b9f1f8684d26c23191be14514c1cf8 Mon Sep 17 00:00:00 2001 From: Dustin Deus Date: Mon, 6 Jan 2025 11:54:36 +0100 Subject: [PATCH] feat(jwk): upgrade JWK library, ensure tokens are validated, retry on network issues (#1488) --- router-tests/authentication_test.go | 9 +- router-tests/go.mod | 5 +- router-tests/go.sum | 10 ++- router-tests/jwks/jwks.go | 85 ++++++++++--------- router-tests/modules/set_scopes_test.go | 3 +- router-tests/utils.go | 10 ++- router-tests/websocket_test.go | 14 +-- router/cmd/instance.go | 5 +- router/cmd/main.go | 13 ++- router/core/router.go | 8 -- router/go.mod | 6 +- router/go.sum | 12 ++- router/internal/httpclient/retryableclient.go | 6 +- .../operationstorage/cdn/client.go | 5 +- router/pkg/authentication/authentication.go | 1 - .../http_header_authenticator.go | 6 -- .../initial_payload_authenticator.go | 6 -- .../pkg/authentication/jwks_token_decoder.go | 84 ++++++++++++------ 18 files changed, 166 insertions(+), 122 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index fde9856ac1..fd0c4bfdc9 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -598,7 +598,8 @@ func TestAuthenticationWithCustomHeaders(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), @@ -733,7 +734,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) { require.NoError(t, err) t.Cleanup(authServer2.Close) - tokenDecoder1, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer1.JWKSURL(), time.Second*5) + tokenDecoder1, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer1.JWKSURL(), time.Second*5) authenticator1HeaderValuePrefixes := []string{"Provider1"} authenticator1, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{ Name: "1", @@ -743,7 +744,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) { }) require.NoError(t, err) - tokenDecoder2, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer2.JWKSURL(), time.Second*5) + tokenDecoder2, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer2.JWKSURL(), time.Second*5) authenticator2HeaderValuePrefixes := []string{"", "Provider2"} authenticator2, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{ Name: "2", @@ -843,7 +844,7 @@ func TestAuthenticationOverWebsocket(t *testing.T) { require.NoError(t, err) defer authServer.Close() - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) jwksOpts := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), diff --git a/router-tests/go.mod b/router-tests/go.mod index 5be7b94560..906b46aa4b 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -3,8 +3,9 @@ module github.com/wundergraph/cosmo/router-tests go 1.23 require ( + github.com/MicahParks/jwkset v0.5.19 github.com/buger/jsonparser v1.1.1 - github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/consul/sdk v0.16.1 @@ -41,7 +42,7 @@ require ( dario.cat/mergo v1.0.0 // indirect github.com/99designs/gqlgen v0.17.49 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect - github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect + github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/Microsoft/hcsshim v0.11.4 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 387f55253f..f53196ad0e 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -11,8 +11,10 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg6 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/IBM/sarama v1.42.1 h1:wugyWa15TDEHh2kvq2gAy1IHLjEjuYOYgXz/ruC/OSQ= github.com/IBM/sarama v1.42.1/go.mod h1:Xxho9HkHd4K/MDUo/T/sOqwtX/17D33++E9Wib6hUdQ= -github.com/MicahParks/keyfunc/v2 v2.1.0 h1:6ZXKb9Rp6qp1bDbJefnG7cTH8yMN1IC/4nf+GVjO99k= -github.com/MicahParks/keyfunc/v2 v2.1.0/go.mod h1:rW42fi+xgLJ2FRRXAfNx9ZA8WpD4OeE/yHVMteCkw9k= +github.com/MicahParks/jwkset v0.5.19 h1:XZCsgJv05DBCvxEHYEHlSafqiuVn5ESG0VRB331Fxhw= +github.com/MicahParks/jwkset v0.5.19/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY= +github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= +github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Microsoft/hcsshim v0.11.4 h1:68vKo2VN8DE9AdN4tnkWnmdhqdbpUFM8OF3Airm7fz8= @@ -132,8 +134,8 @@ github.com/goccy/go-yaml v1.13.4 h1:XOnLX9GqT+kH/gB7YzCMUiDBFU9B7pm3HZz6kyeDPkk= github.com/goccy/go-yaml v1.13.4/go.mod h1:IjYwxUiJDoqpx2RmbdjMUceGHZwYLon3sfOGl5Hi9lc= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= -github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index eec8ad04be..7ab835ee99 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -4,9 +4,8 @@ import ( "context" "crypto/rand" "crypto/rsa" - "encoding/base64" - "encoding/json" - "math/big" + "github.com/MicahParks/jwkset" + "log" "net/http" "net/http/httptest" "testing" @@ -16,8 +15,7 @@ import ( ) const ( - jwtKeyID = "123456789" - signingMethodType = "RSA" // This should match signingMethod below + jwtKeyID = "123456789" jwksHTTPPath = "/.well-known/jwks.json" ) @@ -29,6 +27,7 @@ var ( type Server struct { privateKey *rsa.PrivateKey httpServer *httptest.Server + storage jwkset.Storage } func (s *Server) Close() { @@ -37,45 +36,18 @@ func (s *Server) Close() { func (s *Server) Token(claims map[string]any) (string, error) { token := jwt.NewWithClaims(signingMethod, jwt.MapClaims(claims)) - token.Header["kid"] = jwtKeyID + token.Header[jwkset.HeaderKID] = jwtKeyID return token.SignedString(s.privateKey) } -type jsonWebKeySet struct { - Keys []jsonWebKey `json:"keys"` -} - -type jsonWebKey struct { - Algorithm string `json:"alg"` - Curve string `json:"crv"` - Exponent string `json:"e"` - K string `json:"k"` - ID string `json:"kid"` - Modulus string `json:"n"` - Type string `json:"kty"` - Use string `json:"use"` - X string `json:"x"` - Y string `json:"y"` -} - func (s *Server) jwksJSON(w http.ResponseWriter, r *http.Request) { - k := jsonWebKey{ - Type: signingMethodType, - Algorithm: signingMethod.Name, - Use: "sig", - ID: jwtKeyID, - Exponent: base64.URLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()), - Modulus: base64.URLEncoding.EncodeToString(s.privateKey.N.Bytes()), - } - data, err := json.Marshal(jsonWebKeySet{Keys: []jsonWebKey{k}}) - if err != nil { - panic(err) - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(data) + ctx := context.Background() + + rawJWKS, err := s.storage.JSONPublic(ctx) if err != nil { - panic(err) + log.Fatalf("Failed to get the server's JWKS.\nError: %s", err) } + _, _ = w.Write(rawJWKS) } func (s *Server) JWKSURL() string { @@ -99,12 +71,43 @@ func (s *Server) waitForServer(ctx context.Context) error { } func NewServer(t *testing.T) (*Server, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + ctx := context.Background() + + // Create a cryptographic key. + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate given key.\nError: %s", err) + } + + // Turn the key into a JWK. + marshalOptions := jwkset.JWKMarshalOptions{ + Private: true, + } + metadata := jwkset.JWKMetadataOptions{ + ALG: jwkset.AlgRS256, + KID: jwtKeyID, + USE: jwkset.UseSig, + } + options := jwkset.JWKOptions{ + Marshal: marshalOptions, + Metadata: metadata, + } + + jwk, err := jwkset.NewJWKFromKey(priv, options) if err != nil { - return nil, err + t.Fatalf("Failed to create a JWK from the given key.\nError: %s", err) } + + // Write the JWK to the server's storage. + serverStore := jwkset.NewMemoryStorage() + err = serverStore.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write the JWK to the server's storage.\nError: %s", err) + } + s := &Server{ - privateKey: privateKey, + privateKey: priv, + storage: serverStore, } mux := http.NewServeMux() mux.HandleFunc(jwksHTTPPath, s.jwksJSON) diff --git a/router-tests/modules/set_scopes_test.go b/router-tests/modules/set_scopes_test.go index feceefebbb..278ba4f2e9 100644 --- a/router-tests/modules/set_scopes_test.go +++ b/router-tests/modules/set_scopes_test.go @@ -2,6 +2,7 @@ package module_test import ( "github.com/stretchr/testify/require" + integration "github.com/wundergraph/cosmo/router-tests" "github.com/wundergraph/cosmo/router-tests/jwks" setScopesModule "github.com/wundergraph/cosmo/router-tests/modules/custom-set-scopes" "github.com/wundergraph/cosmo/router-tests/testenv" @@ -27,7 +28,7 @@ func configureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(integration.NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), diff --git a/router-tests/utils.go b/router-tests/utils.go index 874c7052fc..fa06efc2ab 100644 --- a/router-tests/utils.go +++ b/router-tests/utils.go @@ -1,6 +1,7 @@ package integration import ( + "context" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/jwks" "github.com/wundergraph/cosmo/router/pkg/authentication" @@ -15,6 +16,13 @@ const ( jwksName = "my-jwks-server" ) +// NewContextWithCancel creates a new context with a cancel function that is called when the test is done. +func NewContextWithCancel(t *testing.T) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return ctx +} + func RequireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, name string) trace.ReadOnlySpan { require.NotNil(t, exporter) require.NotNil(t, exporter.GetSpans()) @@ -35,7 +43,7 @@ func configureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index d5fd5c1ca0..57d384a988 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -75,7 +75,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), @@ -125,7 +125,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), @@ -175,7 +175,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), @@ -234,7 +234,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.HttpHeaderAuthenticatorOptions{ Name: jwksName, URL: authServer.JWKSURL(), @@ -292,7 +292,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.WebsocketInitialPayloadAuthenticatorOptions{ TokenDecoder: tokenDecoder, Key: "Authorization", @@ -353,7 +353,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.WebsocketInitialPayloadAuthenticatorOptions{ TokenDecoder: tokenDecoder, Key: "Authorization", @@ -402,7 +402,7 @@ func TestWebSockets(t *testing.T) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5) authOptions := authentication.WebsocketInitialPayloadAuthenticatorOptions{ TokenDecoder: tokenDecoder, Key: "Authorization", diff --git a/router/cmd/instance.go b/router/cmd/instance.go index 27a1fbb4d5..051db6a546 100644 --- a/router/cmd/instance.go +++ b/router/cmd/instance.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "net/http" "os" @@ -26,7 +27,7 @@ type Params struct { // NewRouter creates a new router instance. // // additionalOptions can be used to override default options or options provided in the config. -func NewRouter(params Params, additionalOptions ...core.Option) (*core.Router, error) { +func NewRouter(ctx context.Context, params Params, additionalOptions ...core.Option) (*core.Router, error) { // Automatically set GOMAXPROCS to avoid CPU throttling on containerized environments _, err := maxprocs.Set(maxprocs.Logger(params.Logger.Sugar().Debugf)) if err != nil { @@ -62,7 +63,7 @@ func NewRouter(params Params, additionalOptions ...core.Option) (*core.Router, e name = fmt.Sprintf("jwks-#%d", i) } providerLogger := logger.With(zap.String("provider_name", name)) - tokenDecoder, err := authentication.NewJwksTokenDecoder(providerLogger, auth.JWKS.URL, auth.JWKS.RefreshInterval) + tokenDecoder, err := authentication.NewJwksTokenDecoder(ctx, providerLogger, auth.JWKS.URL, auth.JWKS.RefreshInterval) if err != nil { providerLogger.Error("Could not create JWKS token decoder", zap.Error(err)) return nil, err diff --git a/router/cmd/main.go b/router/cmd/main.go index 1d7d8962c6..784cf5b67b 100644 --- a/router/cmd/main.go +++ b/router/cmd/main.go @@ -90,20 +90,19 @@ func Main() { ) } - router, err := NewRouter(Params{ + // Provide a way to cancel all running components of the router after graceful shutdown + // Don't use the parent context that is canceled by the signal handler + routerCtx, routerCancel := context.WithCancel(context.Background()) + defer routerCancel() + + router, err := NewRouter(routerCtx, Params{ Config: &result.Config, Logger: logger, }) - if err != nil { logger.Fatal("Could not create router", zap.Error(err)) } - // Provide a way to cancel all running components of the router after graceful shutdown - // Don't use the parent context that is canceled by the signal handler - routerCtx, routerCancel := context.WithCancel(context.Background()) - defer routerCancel() - if err = router.Start(routerCtx); err != nil { logger.Fatal("Could not start router", zap.Error(err)) } diff --git a/router/core/router.go b/router/core/router.go index 61c6f4a2f8..b7b9642f11 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -1305,14 +1305,6 @@ func (r *Router) Shutdown(ctx context.Context) (err error) { r.persistedOperationClient.Close() } - if r.accessController != nil { - for _, authenticator := range r.accessController.authenticators { - if authenticator != nil { - authenticator.Close() - } - } - } - wg.Wait() return err diff --git a/router/go.mod b/router/go.mod index c8271b6cff..147ccdca55 100644 --- a/router/go.mod +++ b/router/go.mod @@ -4,7 +4,6 @@ go 1.23 require ( connectrpc.com/connect v1.16.2 - github.com/MicahParks/keyfunc/v2 v2.1.0 github.com/andybalholm/brotli v1.1.0 // indirect github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.3.0 @@ -17,7 +16,7 @@ require ( github.com/go-redis/redis_rate/v10 v10.0.1 github.com/gobwas/ws v1.4.0 github.com/goccy/go-yaml v1.13.4 - github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-retryablehttp v0.7.7 @@ -62,6 +61,8 @@ require ( require ( github.com/KimMachineGun/automemlimit v0.6.1 + github.com/MicahParks/jwkset v0.5.19 + github.com/MicahParks/keyfunc/v3 v3.3.5 github.com/bep/debounce v1.2.1 github.com/caarlos0/env/v11 v11.1.0 github.com/expr-lang/expr v1.16.9 @@ -75,6 +76,7 @@ require ( go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/text v0.21.0 + golang.org/x/time v0.5.0 ) require ( diff --git a/router/go.sum b/router/go.sum index f187fb498d..21c67f20ce 100644 --- a/router/go.sum +++ b/router/go.sum @@ -5,8 +5,10 @@ github.com/99designs/gqlgen v0.17.49/go.mod h1:tC8YFVZMed81x7UJ7ORUwXF4Kn6SXuucF github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= -github.com/MicahParks/keyfunc/v2 v2.1.0 h1:6ZXKb9Rp6qp1bDbJefnG7cTH8yMN1IC/4nf+GVjO99k= -github.com/MicahParks/keyfunc/v2 v2.1.0/go.mod h1:rW42fi+xgLJ2FRRXAfNx9ZA8WpD4OeE/yHVMteCkw9k= +github.com/MicahParks/jwkset v0.5.19 h1:XZCsgJv05DBCvxEHYEHlSafqiuVn5ESG0VRB331Fxhw= +github.com/MicahParks/jwkset v0.5.19/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY= +github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= +github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= @@ -98,8 +100,8 @@ github.com/goccy/go-yaml v1.13.4/go.mod h1:IjYwxUiJDoqpx2RmbdjMUceGHZwYLon3sfOGl github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= -github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -358,6 +360,8 @@ golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/router/internal/httpclient/retryableclient.go b/router/internal/httpclient/retryableclient.go index 58a9030486..ea27d2fe8a 100644 --- a/router/internal/httpclient/retryableclient.go +++ b/router/internal/httpclient/retryableclient.go @@ -9,10 +9,14 @@ import ( func NewRetryableHTTPClient(logger *zap.Logger) *http.Client { retryClient := retryablehttp.NewClient() - retryClient.RetryWaitMax = 60 * time.Second + retryClient.RetryWaitMax = 30 * time.Second retryClient.RetryMax = 5 retryClient.Backoff = retryablehttp.DefaultBackoff retryClient.Logger = nil + retryClient.ErrorHandler = func(resp *http.Response, err error, numTries int) (*http.Response, error) { + logger.Error("Request failed", zap.Error(err), zap.Int("numTries", numTries)) + return resp, err + } retryClient.RequestLogHook = func(_ retryablehttp.Logger, _ *http.Request, retry int) { if retry > 0 { logger.Info("Retry request", zap.Int("retry", retry)) diff --git a/router/internal/persistedoperation/operationstorage/cdn/client.go b/router/internal/persistedoperation/operationstorage/cdn/client.go index 833f4b7ef4..fe1648e516 100644 --- a/router/internal/persistedoperation/operationstorage/cdn/client.go +++ b/router/internal/persistedoperation/operationstorage/cdn/client.go @@ -139,7 +139,10 @@ func NewClient(endpoint string, token string, opts Options) (persistedoperation. return nil, err } - logger := opts.Logger.With(zap.String("component", "persisted_operations_client")) + logger := opts.Logger.With( + zap.String("component", "persisted_operations_client"), + zap.String("url", endpoint), + ) return &client{ cdnURL: u, diff --git a/router/pkg/authentication/authentication.go b/router/pkg/authentication/authentication.go index 9c19cade57..5adbf6075d 100644 --- a/router/pkg/authentication/authentication.go +++ b/router/pkg/authentication/authentication.go @@ -22,7 +22,6 @@ type Provider interface { type Authenticator interface { Name() string Authenticate(ctx context.Context, p Provider) (Claims, error) - Close() } type Authentication interface { diff --git a/router/pkg/authentication/http_header_authenticator.go b/router/pkg/authentication/http_header_authenticator.go index 7ee493ea8e..7fbd13b45c 100644 --- a/router/pkg/authentication/http_header_authenticator.go +++ b/router/pkg/authentication/http_header_authenticator.go @@ -19,12 +19,6 @@ type httpHeaderAuthenticator struct { headerValuePrefixes []string } -func (a *httpHeaderAuthenticator) Close() { - if a.tokenDecoder != nil { - a.tokenDecoder.Close() - } -} - func (a *httpHeaderAuthenticator) Name() string { return a.name } diff --git a/router/pkg/authentication/initial_payload_authenticator.go b/router/pkg/authentication/initial_payload_authenticator.go index f5ba1bc4de..3a25bcda91 100644 --- a/router/pkg/authentication/initial_payload_authenticator.go +++ b/router/pkg/authentication/initial_payload_authenticator.go @@ -15,12 +15,6 @@ type websocketInitialPayloadAuthenticator struct { headerValuePrefixes []string } -func (a *websocketInitialPayloadAuthenticator) Close() { - if a.tokenDecoder != nil { - a.tokenDecoder.Close() - } -} - func (a *websocketInitialPayloadAuthenticator) Name() string { return a.name } diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 7b42aea4c2..e845805b97 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -1,23 +1,26 @@ package authentication import ( + "context" "fmt" + "github.com/MicahParks/jwkset" + "github.com/MicahParks/keyfunc/v3" + "github.com/wundergraph/cosmo/router/internal/httpclient" "go.uber.org/zap" + "golang.org/x/time/rate" + "net/http" + "net/url" "time" - "github.com/MicahParks/keyfunc/v2" "github.com/golang-jwt/jwt/v5" ) type TokenDecoder interface { Decode(token string) (Claims, error) - Close() } type jwksTokenDecoder struct { - // JSON Web Key Set, automatically updated in the background - // by keyfunc. - jwks *keyfunc.JWKS + jwks keyfunc.Keyfunc } // Decode implements TokenDecoder. @@ -26,35 +29,68 @@ func (j *jwksTokenDecoder) Decode(tokenString string) (Claims, error) { if err != nil { return nil, fmt.Errorf("could not validate token: %w", err) } + + if !token.Valid { + return nil, fmt.Errorf("token is invalid") + } + claims := token.Claims.(jwt.MapClaims) return Claims(claims), nil } -func NewJwksTokenDecoder(logger *zap.Logger, url string, refreshInterval time.Duration) (TokenDecoder, error) { +func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, u string, refreshInterval time.Duration) (TokenDecoder, error) { - jwks, err := keyfunc.Get(url, keyfunc.Options{ - RefreshInterval: refreshInterval, - // Allow the JWKS to be empty initially, but it can recover on refresh. - TolerateInitialJWKHTTPError: true, - RefreshErrorHandler: func(err error) { - logger.Error("Could not refresh JWKS. Trying again in the next interval.", - zap.Error(err), - zap.String("url", url), - zap.String("interval", refreshInterval.String()), - ) + logger = logger.With(zap.String("url", u)) + + // Create the JWK Set HTTP client. + remoteJWKSets := make(map[string]jwkset.Storage) + + ur, err := url.ParseRequestURI(u) + if err != nil { + return nil, fmt.Errorf("failed to parse given URL %q: %w", u, err) + } + jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{ + Client: httpclient.NewRetryableHTTPClient(logger), + Ctx: ctx, // Used to end background refresh goroutine. + HTTPExpectedStatus: http.StatusOK, + HTTPMethod: http.MethodGet, + HTTPTimeout: 15 * time.Second, + RefreshErrorHandler: func(ctx context.Context, err error) { + logger.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err)) }, - }) + RefreshInterval: refreshInterval, + Storage: nil, + } + store, err := jwkset.NewStorageFromHTTP(ur, jwksetHTTPStorageOptions) if err != nil { - return nil, fmt.Errorf("error initializing JWKS from %q: %w", url, err) + return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err) + } + + remoteJWKSets[ur.String()] = store + + // Create the JWK Set containing HTTP clients and given keys. + jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ + HTTPURLs: remoteJWKSets, + PrioritizeHTTP: false, + RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + } + combined, err := jwkset.NewHTTPClient(jwksetHTTPClientOptions) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err) + } + + keyfuncOptions := keyfunc.Options{ + Ctx: ctx, + Storage: combined, + UseWhitelist: []jwkset.USE{jwkset.UseSig}, + } + + jwks, err := keyfunc.New(keyfuncOptions) + if err != nil { + return nil, fmt.Errorf("error initializing JWK: %w", err) } return &jwksTokenDecoder{ jwks: jwks, }, nil } - -func (j *jwksTokenDecoder) Close() { - if j.jwks != nil { - j.jwks.EndBackground() - } -}