-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathjwks.go
169 lines (139 loc) · 4.75 KB
/
jwks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Package jwks provides helpers for working with json key sets.
package jwks
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/lestrrat-go/jwx/jwk"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// KeySet is a JSON Web Key Set: a collection of jwk.Key values.
// It is declared with jwk.Set (github.com/lestrrat-go/jwx/jwk) as its
// underlying type; see that package's documentation for the available methods.
type KeySet jwk.Set
// KeyProvider looks up keys by key ID.
//
// Implementations may run a background process to keep their keys current;
// Close stops any such process.
type KeyProvider interface {
	// Closer lets callers stop any background process the provider runs.
	io.Closer

	// LookupKey returns the public key for the given key ID (kid). It returns
	// an error if the key is not found or the lookup fails for any other
	// reason. The implementations in this package additionally require the
	// key's algorithm to equal alg.
	LookupKey(ctx context.Context, kid, alg string) (interface{}, error)

	// Fetch returns a clone of the full KeySet; any modifications to the
	// returned set are applied only to that copy.
	Fetch(ctx context.Context) (KeySet, error)
}
// ParseKeySet decodes a JSON-encoded JWK Set from input and returns it as a
// KeySet. The actual parsing is delegated to jwk.ParseString.
func ParseKeySet(input string) (KeySet, error) {
	set, err := jwk.ParseString(input)
	return set, err
}
// cachingKeyProvider resolves keys by kid from the JWKS published at jwksURI.
// The key set is cached by a jwk.AutoRefresh, which keeps it refreshed in the
// background until cancel is invoked.
type cachingKeyProvider struct {
	cancel  context.CancelFunc // cancels the context the AutoRefresh was created with
	ar      *jwk.AutoRefresh   // cache and background refresher for jwksURI
	jwksURI string             // location of the JSON Web Key Set
}
// Close stops the background auto refresh by cancelling the context it was
// started with. Calling Close more than once is safe, because a
// context.CancelFunc does nothing after its first call. It always returns nil.
func (cp *cachingKeyProvider) Close() error {
	cp.cancel()
	return nil
}
// LookupKey resolves kid/alg against the current key set. The set is read from
// the AutoRefresh cache, which refreshes it first if needed.
func (cp *cachingKeyProvider) LookupKey(ctx context.Context, kid, alg string) (interface{}, error) {
	set, fetchErr := cp.ar.Fetch(ctx, cp.jwksURI)
	if fetchErr == nil {
		return publicKeyFromKeySet(set, kid, alg)
	}
	return nil, fetchErr
}
// Fetch returns a copy of the current key set. The set is read from the
// AutoRefresh cache, which refreshes it first if needed.
func (cp *cachingKeyProvider) Fetch(ctx context.Context) (KeySet, error) {
	cached, err := cp.ar.Fetch(ctx, cp.jwksURI)
	if err != nil {
		return nil, err
	}
	// Hand out a clone so callers cannot modify the cached set.
	clone, err := cached.Clone()
	return clone, err
}
// Compile-time check that *cachingKeyProvider satisfies KeyProvider.
var _ KeyProvider = (*cachingKeyProvider)(nil)
// NewCachingOIDCJWKKeyProvider creates a KeyProvider for the given OIDC issuer.
//
// It fetches the issuer's discovery document (issuer + oidc.DiscoveryEndpoint)
// and checks that the issuer it reports is exactly equal to issuer. It then
// loads the key set from the discovered JwksURI. The keys are cached and
// refreshed in the background; call Close on the returned provider to stop
// the background refresh.
//
// ctx is used for the discovery request and the initial key fetch. It is also
// the parent of the context that governs the background refresh, so
// cancelling ctx stops the refresh as well.
//
// oidc.ErrIssuerInvalid is returned (unwrapped) when the discovered issuer
// does not match. Other failures are returned wrapped with context.
func NewCachingOIDCJWKKeyProvider(ctx context.Context, issuer string) (KeyProvider, error) {
	// ctx usually lives as long as the provider, so it cannot be relied on to
	// bound the one-off discovery request; cap that request explicitly.
	const discoveryTimeout = 30 * time.Second

	// Use a dedicated transport for the one-off discovery request and release
	// its idle connections once this function returns.
	httpTransport := http.DefaultTransport.(*http.Transport).Clone()
	defer httpTransport.CloseIdleConnections()
	httpClient := &http.Client{
		Transport: httpTransport,
		Timeout:   discoveryTimeout,
	}

	wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
	if err != nil {
		return nil, fmt.Errorf("creating discovery request for %q: %w", wellKnown, err)
	}

	// Pass the pointer itself, not its address. If a **DiscoveryConfiguration
	// were passed, a JSON `null` body would set discoveryConfig to nil and the
	// issuer check below would panic.
	discoveryConfig := new(oidc.DiscoveryConfiguration)
	if err := httphelper.HttpRequest(httpClient, req, discoveryConfig); err != nil {
		return nil, fmt.Errorf("fetching discovery document %q: %w", wellKnown, err)
	}
	// The reported issuer must be identical to the configured one
	// (OpenID Connect Discovery 1.0, section 4.3).
	if discoveryConfig.Issuer != issuer {
		return nil, oidc.ErrIssuerInvalid
	}

	// refreshCtx is cancelled by the provider's Close (or when ctx ends),
	// which stops the background refresh.
	refreshCtx, cancel := context.WithCancel(ctx)
	ar := jwk.NewAutoRefresh(refreshCtx)
	// Tell *jwk.AutoRefresh that we only want to refresh this JWKS
	// when it needs to (based on Cache-Control or Expires header from
	// the HTTP response). If the calculated minimum refresh interval is less
	// than 15 minutes, don't go refreshing any earlier than 15 minutes.
	ar.Configure(discoveryConfig.JwksURI, jwk.WithMinRefreshInterval(15*time.Minute))

	// Refresh the JWKS once before we start our service so a bad jwks_uri
	// fails construction instead of the first lookup.
	if _, err := ar.Refresh(refreshCtx, discoveryConfig.JwksURI); err != nil {
		cancel()
		return nil, fmt.Errorf("fetching JWKS %q: %w", discoveryConfig.JwksURI, err)
	}

	return &cachingKeyProvider{
		cancel:  cancel,
		ar:      ar,
		jwksURI: discoveryConfig.JwksURI,
	}, nil
}
// staticKeySet is a KeyProvider backed by a fixed KeySet. It never refreshes
// the set and runs no background work.
type staticKeySet struct {
	keyset KeySet // the keys served by this provider
}
// Compile-time check that *staticKeySet satisfies KeyProvider.
var _ KeyProvider = (*staticKeySet)(nil)
// LookupKey resolves kid/alg against the static key set. ctx is not used.
func (p *staticKeySet) LookupKey(ctx context.Context, kid, alg string) (interface{}, error) {
	key, err := publicKeyFromKeySet(p.keyset, kid, alg)
	return key, err
}
// Close is a no-op: a static key set has no background work to stop.
func (p *staticKeySet) Close() error { return nil }
// Fetch returns a clone of the static key set, so callers cannot modify the
// keys held by the provider. ctx is not used.
func (p *staticKeySet) Fetch(ctx context.Context) (KeySet, error) {
	clone, err := p.keyset.Clone()
	return clone, err
}
// NewStaticJWKKeyProvider returns a KeyProvider that serves keys from the given
// keyset. The set is used as-is and never refreshed, and Close does nothing.
func NewStaticJWKKeyProvider(keyset KeySet) KeyProvider {
	p := new(staticKeySet)
	p.keyset = keyset
	return p
}
// publicKeyFromKeySet looks up the key identified by kid in keyset, requires
// its algorithm to equal alg, and returns the key's raw value as produced by
// jwk.Key.Raw.
//
// The algorithm comparison is strict equality; presumably a key with no "alg"
// member reports "" and so never matches a non-empty alg. Verify this against
// the jwx version in use before relying on it.
func publicKeyFromKeySet(keyset KeySet, kid, alg string) (interface{}, error) {
	key, found := keyset.LookupKeyID(kid)
	switch {
	case !found:
		return nil, errors.New("kid header does not exist in keyset")
	case key.Algorithm() != alg:
		return nil, errors.New("key from kid has different signing alg")
	}

	var raw interface{}
	if err := key.Raw(&raw); err != nil {
		return nil, fmt.Errorf("error getting raw key: %w", err)
	}
	return raw, nil
}