Skip to content

Commit

Permalink
fix: ensure RSA key length fullfills 4096bit requirement (#2905) (#3402)
Browse files Browse the repository at this point in the history
Closes #2905

Co-authored-by: Arne <a.luenser@gmail.com>
Co-authored-by: aeneasr <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 16, 2023
1 parent af40d16 commit a663927
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 49 deletions.
19 changes: 10 additions & 9 deletions hsm/manager_hsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type KeyManager struct {
jwk.Manager
sync.RWMutex
Context
KeySetPrefix string
c config.DefaultProvider
}

var ErrPreGeneratedKeys = &fosite.RFC6749Error{
Expand All @@ -53,8 +53,8 @@ var ErrPreGeneratedKeys = &fosite.RFC6749Error{

func NewKeyManager(hsm Context, config *config.DefaultProvider) *KeyManager {
return &KeyManager{
Context: hsm,
KeySetPrefix: config.HSMKeySetPrefix(),
Context: hsm,
c: *config,
}
}

Expand Down Expand Up @@ -142,7 +142,7 @@ func (m *KeyManager) GetKey(ctx context.Context, set, kid string) (*jose.JSONWeb
return nil, errors.WithStack(x.ErrNotFound)
}

id, alg, use, err := getKeySetAttributes(m, keyPair, []byte(kid))
id, alg, use, err := m.getKeySetAttributes(ctx, keyPair, []byte(kid))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func (m *KeyManager) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKe

var keys []jose.JSONWebKey
for _, keyPair := range keyPairs {
kid, alg, use, err := getKeySetAttributes(m, keyPair, nil)
kid, alg, use, err := m.getKeySetAttributes(ctx, keyPair, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -263,7 +263,7 @@ func (m *KeyManager) UpdateKeySet(_ context.Context, _ string, _ *jose.JSONWebKe
return errors.WithStack(ErrPreGeneratedKeys)
}

func getKeySetAttributes(m *KeyManager, key crypto11.Signer, kid []byte) (string, string, string, error) {
func (m *KeyManager) getKeySetAttributes(ctx context.Context, key crypto11.Signer, kid []byte) (string, string, string, error) {
if kid == nil {
ckaId, err := m.GetAttribute(key, crypto11.CkaId)
if err != nil {
Expand All @@ -276,8 +276,9 @@ func getKeySetAttributes(m *KeyManager, key crypto11.Signer, kid []byte) (string
switch k := key.Public().(type) {
case *rsa.PublicKey:
alg = "RS256"
// TODO Should we validate minimal key length by checking CKA_MODULUS_BITS?
// TODO see https://github.com/ory/hydra/issues/2905
if k.N.BitLen() < 4096 && !m.c.IsDevelopmentMode(ctx) {
return "", "", "", errors.WithStack(jwk.ErrMinimalRsaKeyLength)
}
case *ecdsa.PublicKey:
if k.Curve == elliptic.P521() {
alg = "ES512"
Expand Down Expand Up @@ -365,5 +366,5 @@ func createKeys(key crypto11.Signer, kid, alg, use string) []jose.JSONWebKey {
}

func (m *KeyManager) prefixKeySet(set string) string {
return fmt.Sprintf("%s%s", m.KeySetPrefix, set)
return fmt.Sprintf("%s%s", m.c.HSMKeySetPrefix(), set)
}
99 changes: 59 additions & 40 deletions hsm/manager_hsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,78 +62,97 @@ func TestKeyManager_HsmKeySetPrefix(t *testing.T) {
ctrl := gomock.NewController(t)
hsmContext := NewMockContext(ctrl)
defer ctrl.Finish()
l := logrusx.New("", "")
c := config.MustNew(context.Background(), l, configx.SkipValidation())
keySetPrefix := "application_specific_prefix."
c.MustSet(context.Background(), config.HSMKeySetPrefix, keySetPrefix)
m := hsm.NewKeyManager(hsmContext, c)

rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
rsaKey3072, err := rsa.GenerateKey(rand.Reader, 3072)
require.NoError(t, err)
rsaKey4096, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)

ecdsaKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
require.NoError(t, err)

rsaKeyPair := NewMockSignerDecrypter(ctrl)
rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes()
rsaKeyPair3072 := NewMockSignerDecrypter(ctrl)
rsaKeyPair3072.EXPECT().Public().Return(&rsaKey3072.PublicKey).AnyTimes()

rsaKeyPair4096 := NewMockSignerDecrypter(ctrl)
rsaKeyPair4096.EXPECT().Public().Return(&rsaKey4096.PublicKey).AnyTimes()

ecdsaKeyPair := NewMockSignerDecrypter(ctrl)
ecdsaKeyPair.EXPECT().Public().Return(&ecdsaKey.PublicKey).AnyTimes()

var kid = uuid.New()

keySetPrefix := "application_specific_prefix."
expectedPrefixedOpenIDConnectKeyName := fmt.Sprintf("%s%s", keySetPrefix, x.OpenIDConnectKeyName)

m := &hsm.KeyManager{
Context: hsmContext,
KeySetPrefix: keySetPrefix,
}

t.Run("case=GenerateAndPersistKeySet", func(t *testing.T) {
privateAttrSet, publicAttrSet := expectedKeyAttributes(t, expectedPrefixedOpenIDConnectKeyName, kid)
hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(nil, nil)
hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(rsaKeyPair, nil)
hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(rsaKeyPair4096, nil)

got, err := m.GenerateAndPersistKeySet(context.TODO(), x.OpenIDConnectKeyName, kid, "RS256", "sig")

assert.NoError(t, err)
expectedKeySet := expectedKeySet(rsaKeyPair, kid, "RS256", "sig")
expectedKeySet := expectedKeySet(rsaKeyPair4096, kid, "RS256", "sig")
if !reflect.DeepEqual(got, expectedKeySet) {
t.Errorf("GenerateAndPersistKeySet() got = %v, want %v", got, expectedKeySet)
}
})
t.Run("case=GetKey", func(t *testing.T) {
hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair, nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair4096, nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair4096), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)

got, err := m.GetKey(context.TODO(), x.OpenIDConnectKeyName, kid)

assert.NoError(t, err)
expectedKeySet := expectedKeySet(rsaKeyPair, kid, "RS256", "sig")
expectedKeySet := expectedKeySet(rsaKeyPair4096, kid, "RS256", "sig")
if !reflect.DeepEqual(got, expectedKeySet) {
t.Errorf("GetKey() got = %v, want %v", got, expectedKeySet)
}
})
t.Run("case=GetKeyMinimalRsaKeyLengthError", func(t *testing.T) {
hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair3072, nil)

_, err := m.GetKey(context.TODO(), x.OpenIDConnectKeyName, kid)

assert.ErrorIs(t, err, jwk.ErrMinimalRsaKeyLength)
})
t.Run("case=GetKeySet", func(t *testing.T) {
hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair}, nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(kid)), nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair4096}, nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair4096), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(kid)), nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair4096), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)

got, err := m.GetKeySet(context.TODO(), x.OpenIDConnectKeyName)

assert.NoError(t, err)
expectedKeySet := expectedKeySet(rsaKeyPair, kid, "RS256", "sig")
expectedKeySet := expectedKeySet(rsaKeyPair4096, kid, "RS256", "sig")
if !reflect.DeepEqual(got, expectedKeySet) {
t.Errorf("GetKey() got = %v, want %v", got, expectedKeySet)
}
})
t.Run("case=GetKeySetMinimalRsaKeyLengthError", func(t *testing.T) {
hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair3072}, nil)
hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair3072), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(kid)), nil)

_, err := m.GetKeySet(context.TODO(), x.OpenIDConnectKeyName)

assert.ErrorIs(t, err, jwk.ErrMinimalRsaKeyLength)
})
t.Run("case=DeleteKey", func(t *testing.T) {
hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair, nil)
rsaKeyPair.EXPECT().Delete().Return(nil)
hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair4096, nil)
rsaKeyPair4096.EXPECT().Delete().Return(nil)

err := m.DeleteKey(context.TODO(), x.OpenIDConnectKeyName, kid)

assert.NoError(t, err)
})
t.Run("case=DeleteKeySet", func(t *testing.T) {
hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair}, nil)
rsaKeyPair.EXPECT().Delete().Return(nil)
hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair4096}, nil)
rsaKeyPair4096.EXPECT().Delete().Return(nil)

err := m.DeleteKeySet(context.TODO(), x.OpenIDConnectKeyName)

Expand All @@ -145,8 +164,11 @@ func TestKeyManager_GenerateAndPersistKeySet(t *testing.T) {
ctrl := gomock.NewController(t)
hsmContext := NewMockContext(ctrl)
defer ctrl.Finish()
l := logrusx.New("", "")
c := config.MustNew(context.Background(), l, configx.SkipValidation())
m := hsm.NewKeyManager(hsmContext, c)

rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)

ecdsaKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
Expand Down Expand Up @@ -303,9 +325,6 @@ func TestKeyManager_GenerateAndPersistKeySet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup(t)
m := &hsm.KeyManager{
Context: hsmContext,
}
got, err := m.GenerateAndPersistKeySet(tt.args.ctx, tt.args.set, tt.args.kid, tt.args.alg, tt.args.use)
if tt.wantErr != nil {
require.Nil(t, got)
Expand All @@ -326,8 +345,11 @@ func TestKeyManager_GetKey(t *testing.T) {
ctrl := gomock.NewController(t)
hsmContext := NewMockContext(ctrl)
defer ctrl.Finish()
l := logrusx.New("", "")
c := config.MustNew(context.Background(), l, configx.SkipValidation())
m := hsm.NewKeyManager(hsmContext, c)

rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)
rsaKeyPair := NewMockSignerDecrypter(ctrl)
rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes()
Expand Down Expand Up @@ -493,9 +515,6 @@ func TestKeyManager_GetKey(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup(t)
m := &hsm.KeyManager{
Context: hsmContext,
}
got, err := m.GetKey(tt.args.ctx, tt.args.set, tt.args.kid)
if tt.wantErr != nil {
require.Nil(t, got)
Expand All @@ -516,8 +535,11 @@ func TestKeyManager_GetKeySet(t *testing.T) {
ctrl := gomock.NewController(t)
hsmContext := NewMockContext(ctrl)
defer ctrl.Finish()
l := logrusx.New("", "")
c := config.MustNew(context.Background(), l, configx.SkipValidation())
m := hsm.NewKeyManager(hsmContext, c)

rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)
rsaKid := uuid.New()
rsaKeyPair := NewMockSignerDecrypter(ctrl)
Expand Down Expand Up @@ -641,9 +663,6 @@ func TestKeyManager_GetKeySet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup(t)
m := &hsm.KeyManager{
Context: hsmContext,
}
got, err := m.GetKeySet(tt.args.ctx, tt.args.set)
if tt.wantErr != nil {
require.Nil(t, got)
Expand All @@ -664,6 +683,9 @@ func TestKeyManager_DeleteKey(t *testing.T) {
ctrl := gomock.NewController(t)
hsmContext := NewMockContext(ctrl)
defer ctrl.Finish()
l := logrusx.New("", "")
c := config.MustNew(context.Background(), l, configx.SkipValidation())
m := hsm.NewKeyManager(hsmContext, c)

rsaKeyPair := NewMockSignerDecrypter(ctrl)

Expand Down Expand Up @@ -733,9 +755,6 @@ func TestKeyManager_DeleteKey(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup(t)
m := &hsm.KeyManager{
Context: hsmContext,
}
if err := m.DeleteKey(tt.args.ctx, tt.args.set, tt.args.kid); len(tt.wantErrMsg) != 0 {
require.EqualError(t, err, tt.wantErrMsg)
}
Expand All @@ -747,6 +766,9 @@ func TestKeyManager_DeleteKeySet(t *testing.T) {
ctrl := gomock.NewController(t)
hsmContext := NewMockContext(ctrl)
defer ctrl.Finish()
l := logrusx.New("", "")
c := config.MustNew(context.Background(), l, configx.SkipValidation())
m := hsm.NewKeyManager(hsmContext, c)

rsaKeyPair1 := NewMockSignerDecrypter(ctrl)
rsaKeyPair2 := NewMockSignerDecrypter(ctrl)
Expand Down Expand Up @@ -812,9 +834,6 @@ func TestKeyManager_DeleteKeySet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup(t)
m := &hsm.KeyManager{
Context: hsmContext,
}
if err := m.DeleteKeySet(tt.args.ctx, tt.args.set); len(tt.wantErrMsg) != 0 {
require.EqualError(t, err, tt.wantErrMsg)
}
Expand Down
6 changes: 6 additions & 0 deletions jwk/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ var ErrUnsupportedEllipticCurve = &fosite.RFC6749Error{
DescriptionField: "Unsupported elliptic curve",
}

var ErrMinimalRsaKeyLength = &fosite.RFC6749Error{
CodeField: http.StatusBadRequest,
ErrorField: http.StatusText(http.StatusBadRequest),
DescriptionField: "Unsupported RSA key length",
}

type (
Manager interface {
GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error)
Expand Down

0 comments on commit a663927

Please sign in to comment.