Skip to content

Commit

Permalink
feat(jwk): upgrade JWK library, ensure tokens are validated, retry on…
Browse files Browse the repository at this point in the history
… network issues (wundergraph#1488)
  • Loading branch information
StarpTech authored Jan 6, 2025
1 parent 34993f5 commit faab120
Show file tree
Hide file tree
Showing 18 changed files with 166 additions and 122 deletions.
9 changes: 5 additions & 4 deletions router-tests/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ func TestAuthenticationWithCustomHeaders(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)

tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down Expand Up @@ -733,7 +734,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) {
require.NoError(t, err)
t.Cleanup(authServer2.Close)

tokenDecoder1, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer1.JWKSURL(), time.Second*5)
tokenDecoder1, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer1.JWKSURL(), time.Second*5)
authenticator1HeaderValuePrefixes := []string{"Provider1"}
authenticator1, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{
Name: "1",
Expand All @@ -743,7 +744,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) {
})
require.NoError(t, err)

tokenDecoder2, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer2.JWKSURL(), time.Second*5)
tokenDecoder2, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer2.JWKSURL(), time.Second*5)
authenticator2HeaderValuePrefixes := []string{"", "Provider2"}
authenticator2, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{
Name: "2",
Expand Down Expand Up @@ -843,7 +844,7 @@ func TestAuthenticationOverWebsocket(t *testing.T) {
require.NoError(t, err)
defer authServer.Close()

tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
jwksOpts := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down
5 changes: 3 additions & 2 deletions router-tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module github.com/wundergraph/cosmo/router-tests
go 1.23

require (
github.com/MicahParks/jwkset v0.5.19
github.com/buger/jsonparser v1.1.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.1
github.com/hashicorp/consul/sdk v0.16.1
Expand Down Expand Up @@ -41,7 +42,7 @@ require (
dario.cat/mergo v1.0.0 // indirect
github.com/99designs/gqlgen v0.17.49 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect
github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/Microsoft/hcsshim v0.11.4 // indirect
github.com/agnivade/levenshtein v1.1.1 // indirect
Expand Down
10 changes: 6 additions & 4 deletions router-tests/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg6
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/IBM/sarama v1.42.1 h1:wugyWa15TDEHh2kvq2gAy1IHLjEjuYOYgXz/ruC/OSQ=
github.com/IBM/sarama v1.42.1/go.mod h1:Xxho9HkHd4K/MDUo/T/sOqwtX/17D33++E9Wib6hUdQ=
github.com/MicahParks/keyfunc/v2 v2.1.0 h1:6ZXKb9Rp6qp1bDbJefnG7cTH8yMN1IC/4nf+GVjO99k=
github.com/MicahParks/keyfunc/v2 v2.1.0/go.mod h1:rW42fi+xgLJ2FRRXAfNx9ZA8WpD4OeE/yHVMteCkw9k=
github.com/MicahParks/jwkset v0.5.19 h1:XZCsgJv05DBCvxEHYEHlSafqiuVn5ESG0VRB331Fxhw=
github.com/MicahParks/jwkset v0.5.19/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY=
github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo=
github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/Microsoft/hcsshim v0.11.4 h1:68vKo2VN8DE9AdN4tnkWnmdhqdbpUFM8OF3Airm7fz8=
Expand Down Expand Up @@ -132,8 +134,8 @@ github.com/goccy/go-yaml v1.13.4 h1:XOnLX9GqT+kH/gB7YzCMUiDBFU9B7pm3HZz6kyeDPkk=
github.com/goccy/go-yaml v1.13.4/go.mod h1:IjYwxUiJDoqpx2RmbdjMUceGHZwYLon3sfOGl5Hi9lc=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
Expand Down
85 changes: 44 additions & 41 deletions router-tests/jwks/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"math/big"
"github.com/MicahParks/jwkset"
"log"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -16,8 +15,7 @@ import (
)

const (
jwtKeyID = "123456789"
signingMethodType = "RSA" // This should match signingMethod below
jwtKeyID = "123456789"

jwksHTTPPath = "/.well-known/jwks.json"
)
Expand All @@ -29,6 +27,7 @@ var (
type Server struct {
privateKey *rsa.PrivateKey
httpServer *httptest.Server
storage jwkset.Storage
}

func (s *Server) Close() {
Expand All @@ -37,45 +36,18 @@ func (s *Server) Close() {

func (s *Server) Token(claims map[string]any) (string, error) {
token := jwt.NewWithClaims(signingMethod, jwt.MapClaims(claims))
token.Header["kid"] = jwtKeyID
token.Header[jwkset.HeaderKID] = jwtKeyID
return token.SignedString(s.privateKey)
}

type jsonWebKeySet struct {
Keys []jsonWebKey `json:"keys"`
}

type jsonWebKey struct {
Algorithm string `json:"alg"`
Curve string `json:"crv"`
Exponent string `json:"e"`
K string `json:"k"`
ID string `json:"kid"`
Modulus string `json:"n"`
Type string `json:"kty"`
Use string `json:"use"`
X string `json:"x"`
Y string `json:"y"`
}

func (s *Server) jwksJSON(w http.ResponseWriter, r *http.Request) {
k := jsonWebKey{
Type: signingMethodType,
Algorithm: signingMethod.Name,
Use: "sig",
ID: jwtKeyID,
Exponent: base64.URLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()),
Modulus: base64.URLEncoding.EncodeToString(s.privateKey.N.Bytes()),
}
data, err := json.Marshal(jsonWebKeySet{Keys: []jsonWebKey{k}})
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(data)
ctx := context.Background()

rawJWKS, err := s.storage.JSONPublic(ctx)
if err != nil {
panic(err)
log.Fatalf("Failed to get the server's JWKS.\nError: %s", err)
}
_, _ = w.Write(rawJWKS)
}

func (s *Server) JWKSURL() string {
Expand All @@ -99,12 +71,43 @@ func (s *Server) waitForServer(ctx context.Context) error {
}

func NewServer(t *testing.T) (*Server, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
ctx := context.Background()

// Create a cryptographic key.
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate given key.\nError: %s", err)
}

// Turn the key into a JWK.
marshalOptions := jwkset.JWKMarshalOptions{
Private: true,
}
metadata := jwkset.JWKMetadataOptions{
ALG: jwkset.AlgRS256,
KID: jwtKeyID,
USE: jwkset.UseSig,
}
options := jwkset.JWKOptions{
Marshal: marshalOptions,
Metadata: metadata,
}

jwk, err := jwkset.NewJWKFromKey(priv, options)
if err != nil {
return nil, err
t.Fatalf("Failed to create a JWK from the given key.\nError: %s", err)
}

// Write the JWK to the server's storage.
serverStore := jwkset.NewMemoryStorage()
err = serverStore.KeyWrite(ctx, jwk)
if err != nil {
t.Fatalf("Failed to write the JWK to the server's storage.\nError: %s", err)
}

s := &Server{
privateKey: privateKey,
privateKey: priv,
storage: serverStore,
}
mux := http.NewServeMux()
mux.HandleFunc(jwksHTTPPath, s.jwksJSON)
Expand Down
3 changes: 2 additions & 1 deletion router-tests/modules/set_scopes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package module_test

import (
"github.com/stretchr/testify/require"
integration "github.com/wundergraph/cosmo/router-tests"
"github.com/wundergraph/cosmo/router-tests/jwks"
setScopesModule "github.com/wundergraph/cosmo/router-tests/modules/custom-set-scopes"
"github.com/wundergraph/cosmo/router-tests/testenv"
Expand All @@ -27,7 +28,7 @@ func configureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server)
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(integration.NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down
10 changes: 9 additions & 1 deletion router-tests/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package integration

import (
"context"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/jwks"
"github.com/wundergraph/cosmo/router/pkg/authentication"
Expand All @@ -15,6 +16,13 @@ const (
jwksName = "my-jwks-server"
)

// NewContextWithCancel creates a new context with a cancel function that is called when the test is done.
func NewContextWithCancel(t *testing.T) context.Context {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
return ctx
}

func RequireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, name string) trace.ReadOnlySpan {
require.NotNil(t, exporter)
require.NotNil(t, exporter.GetSpans())
Expand All @@ -35,7 +43,7 @@ func configureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server)
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down
14 changes: 7 additions & 7 deletions router-tests/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down Expand Up @@ -125,7 +125,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down Expand Up @@ -175,7 +175,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: jwksName,
URL: authServer.JWKSURL(),
Expand Down Expand Up @@ -292,7 +292,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.WebsocketInitialPayloadAuthenticatorOptions{
TokenDecoder: tokenDecoder,
Key: "Authorization",
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.WebsocketInitialPayloadAuthenticatorOptions{
TokenDecoder: tokenDecoder,
Key: "Authorization",
Expand Down Expand Up @@ -402,7 +402,7 @@ func TestWebSockets(t *testing.T) {
authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5)
tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), authServer.JWKSURL(), time.Second*5)
authOptions := authentication.WebsocketInitialPayloadAuthenticatorOptions{
TokenDecoder: tokenDecoder,
Key: "Authorization",
Expand Down
5 changes: 3 additions & 2 deletions router/cmd/instance.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"net/http"
"os"
Expand All @@ -26,7 +27,7 @@ type Params struct {
// NewRouter creates a new router instance.
//
// additionalOptions can be used to override default options or options provided in the config.
func NewRouter(params Params, additionalOptions ...core.Option) (*core.Router, error) {
func NewRouter(ctx context.Context, params Params, additionalOptions ...core.Option) (*core.Router, error) {
// Automatically set GOMAXPROCS to avoid CPU throttling on containerized environments
_, err := maxprocs.Set(maxprocs.Logger(params.Logger.Sugar().Debugf))
if err != nil {
Expand Down Expand Up @@ -62,7 +63,7 @@ func NewRouter(params Params, additionalOptions ...core.Option) (*core.Router, e
name = fmt.Sprintf("jwks-#%d", i)
}
providerLogger := logger.With(zap.String("provider_name", name))
tokenDecoder, err := authentication.NewJwksTokenDecoder(providerLogger, auth.JWKS.URL, auth.JWKS.RefreshInterval)
tokenDecoder, err := authentication.NewJwksTokenDecoder(ctx, providerLogger, auth.JWKS.URL, auth.JWKS.RefreshInterval)
if err != nil {
providerLogger.Error("Could not create JWKS token decoder", zap.Error(err))
return nil, err
Expand Down
13 changes: 6 additions & 7 deletions router/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,19 @@ func Main() {
)
}

router, err := NewRouter(Params{
// Provide a way to cancel all running components of the router after graceful shutdown
// Don't use the parent context that is canceled by the signal handler
routerCtx, routerCancel := context.WithCancel(context.Background())
defer routerCancel()

router, err := NewRouter(routerCtx, Params{
Config: &result.Config,
Logger: logger,
})

if err != nil {
logger.Fatal("Could not create router", zap.Error(err))
}

// Provide a way to cancel all running components of the router after graceful shutdown
// Don't use the parent context that is canceled by the signal handler
routerCtx, routerCancel := context.WithCancel(context.Background())
defer routerCancel()

if err = router.Start(routerCtx); err != nil {
logger.Fatal("Could not start router", zap.Error(err))
}
Expand Down
8 changes: 0 additions & 8 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -1305,14 +1305,6 @@ func (r *Router) Shutdown(ctx context.Context) (err error) {
r.persistedOperationClient.Close()
}

if r.accessController != nil {
for _, authenticator := range r.accessController.authenticators {
if authenticator != nil {
authenticator.Close()
}
}
}

wg.Wait()

return err
Expand Down
Loading

0 comments on commit faab120

Please sign in to comment.