Skip to content

Commit c1b380c

Browse files
fix(core): Requires unique kids (#1905)
### Proposed Changes - When loading KAS keys, validate that key identifiers are unique - This will be required as we move to a key ID - first (instead of alg first) method of selecting keys - While here, moved some of the 'legacy conversion' code into more testable functions and generally improved testing and handling/detection of error conditions, including increasing test coverage of loading/displaying EC keys ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent 95372d3 commit c1b380c

File tree

7 files changed

+137
-109
lines changed

7 files changed

+137
-109
lines changed

service/internal/security/standard_crypto.go

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@ type StandardECCrypto struct {
6666
type keylist map[string]any
6767

6868
type StandardCrypto struct {
69-
// Lists of keys first sorted by algorithm
70-
keys map[string]keylist
69+
// Lists of keysByAlg first sorted by algorithm
70+
keysByAlg map[string]keylist
71+
72+
// Lists all keys by identifier.
73+
keysByID keylist
7174
}
7275

7376
// NewStandardCrypto Create a new instance of standard crypto
@@ -83,20 +86,26 @@ func NewStandardCrypto(cfg StandardConfig) (*StandardCrypto, error) {
8386
}
8487

8588
func loadKeys(ks []KeyPairInfo) (*StandardCrypto, error) {
86-
keys := make(map[string]keylist)
89+
keysByAlg := make(map[string]keylist)
90+
keysByID := make(keylist)
8791
for _, k := range ks {
8892
slog.Info("crypto cfg loading", "id", k.KID, "alg", k.Algorithm)
89-
if _, ok := keys[k.Algorithm]; !ok {
90-
keys[k.Algorithm] = make(map[string]any)
93+
if _, ok := keysByID[k.KID]; ok {
94+
return nil, fmt.Errorf("duplicate key identifier [%s]", k.KID)
95+
}
96+
if _, ok := keysByAlg[k.Algorithm]; !ok {
97+
keysByAlg[k.Algorithm] = make(map[string]any)
9198
}
9299
loadedKey, err := loadKey(k)
93100
if err != nil {
94101
return nil, err
95102
}
96-
keys[k.Algorithm][k.KID] = loadedKey
103+
keysByAlg[k.Algorithm][k.KID] = loadedKey
104+
keysByID[k.KID] = loadedKey
97105
}
98106
return &StandardCrypto{
99-
keys: keys,
107+
keysByAlg: keysByAlg,
108+
keysByID: keysByID,
100109
}, nil
101110
}
102111

@@ -139,13 +148,14 @@ func loadKey(k KeyPairInfo) (any, error) {
139148
}
140149

141150
func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]StandardKeyInfo) (*StandardCrypto, error) {
142-
keys := make(map[string]keylist)
151+
keysByAlg := make(map[string]keylist)
152+
keysByID := make(keylist)
143153

144154
if len(ecKeys) > 0 {
145-
keys[AlgorithmECP256R1] = make(map[string]any)
155+
keysByAlg[AlgorithmECP256R1] = make(map[string]any)
146156
}
147157
if len(rsaKeys) > 0 {
148-
keys[AlgorithmRSA2048] = make(map[string]any)
158+
keysByAlg[AlgorithmRSA2048] = make(map[string]any)
149159
}
150160

151161
for id, kasInfo := range rsaKeys {
@@ -169,7 +179,7 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
169179
return nil, fmt.Errorf("ocrypto.NewAsymEncryption failed: %w", err)
170180
}
171181

172-
keys[AlgorithmRSA2048][id] = StandardRSACrypto{
182+
k := StandardRSACrypto{
173183
KeyPairInfo: KeyPairInfo{
174184
Algorithm: AlgorithmRSA2048,
175185
KID: id,
@@ -179,6 +189,8 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
179189
asymDecryption: asymDecryption,
180190
asymEncryption: asymEncryption,
181191
}
192+
keysByAlg[AlgorithmRSA2048][id] = k
193+
keysByID[id] = k
182194
}
183195
for id, kasInfo := range ecKeys {
184196
slog.Info("cfg.ECKeys", "id", id, "kasInfo", kasInfo)
@@ -192,7 +204,7 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
192204
if err != nil {
193205
return nil, fmt.Errorf("failed to EC certificate file: %w", err)
194206
}
195-
keys[AlgorithmECP256R1][id] = StandardECCrypto{
207+
k := StandardECCrypto{
196208
KeyPairInfo: KeyPairInfo{
197209
Algorithm: AlgorithmRSA2048,
198210
KID: id,
@@ -202,15 +214,18 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
202214
ecPrivateKeyPem: string(privatePemData),
203215
ecCertificatePEM: string(ecCertificatePEM),
204216
}
217+
keysByAlg[AlgorithmECP256R1][id] = k
218+
keysByID[id] = k
205219
}
206220

207221
return &StandardCrypto{
208-
keys: keys,
222+
keysByAlg: keysByAlg,
223+
keysByID: keysByID,
209224
}, nil
210225
}
211226

212227
func (s StandardCrypto) FindKID(alg string) string {
213-
if ks, ok := s.keys[alg]; ok && len(ks) > 0 {
228+
if ks, ok := s.keysByAlg[alg]; ok && len(ks) > 0 {
214229
for kid := range ks {
215230
return kid
216231
}
@@ -219,17 +234,13 @@ func (s StandardCrypto) FindKID(alg string) string {
219234
}
220235

221236
func (s StandardCrypto) RSAPublicKey(kid string) (string, error) {
222-
rsaKeys, ok := s.keys[AlgorithmRSA2048]
223-
if !ok || len(rsaKeys) == 0 {
224-
return "", ErrCertNotFound
225-
}
226-
k, ok := rsaKeys[kid]
237+
k, ok := s.keysByID[kid]
227238
if !ok {
228-
return "", ErrCertNotFound
239+
return "", fmt.Errorf("no rsa key with id [%s]: %w", kid, ErrCertNotFound)
229240
}
230241
rsa, ok := k.(StandardRSACrypto)
231242
if !ok {
232-
return "", ErrCertNotFound
243+
return "", fmt.Errorf("key with id [%s] is not an RSA key: %w", kid, ErrCertNotFound)
233244
}
234245

235246
pem, err := rsa.asymEncryption.PublicKeyInPemFormat()
@@ -241,27 +252,19 @@ func (s StandardCrypto) RSAPublicKey(kid string) (string, error) {
241252
}
242253

243254
func (s StandardCrypto) ECCertificate(kid string) (string, error) {
244-
ecKeys, ok := s.keys[AlgorithmECP256R1]
245-
if !ok || len(ecKeys) == 0 {
246-
return "", ErrCertNotFound
247-
}
248-
k, ok := ecKeys[kid]
255+
k, ok := s.keysByID[kid]
249256
if !ok {
250-
return "", ErrCertNotFound
257+
return "", fmt.Errorf("no ec key with id [%s]: %w", kid, ErrCertNotFound)
251258
}
252259
ec, ok := k.(StandardECCrypto)
253260
if !ok {
254-
return "", ErrCertNotFound
261+
return "", fmt.Errorf("key with id [%s] is not an EC key: %w", kid, ErrCertNotFound)
255262
}
256263
return ec.ecCertificatePEM, nil
257264
}
258265

259266
func (s StandardCrypto) ECPublicKey(kid string) (string, error) {
260-
ecKeys, ok := s.keys[AlgorithmECP256R1]
261-
if !ok || len(ecKeys) == 0 {
262-
return "", ErrCertNotFound
263-
}
264-
k, ok := ecKeys[kid]
267+
k, ok := s.keysByID[kid]
265268
if !ok {
266269
return "", ErrCertNotFound
267270
}
@@ -293,11 +296,7 @@ func (s StandardCrypto) ECPublicKey(kid string) (string, error) {
293296
}
294297

295298
func (s StandardCrypto) RSADecrypt(_ crypto.Hash, kid string, _ string, ciphertext []byte) ([]byte, error) {
296-
rsaKeys, ok := s.keys[AlgorithmRSA2048]
297-
if !ok || len(rsaKeys) == 0 {
298-
return nil, ErrCertNotFound
299-
}
300-
k, ok := rsaKeys[kid]
299+
k, ok := s.keysByID[kid]
301300
if !ok {
302301
return nil, ErrCertNotFound
303302
}
@@ -315,11 +314,10 @@ func (s StandardCrypto) RSADecrypt(_ crypto.Hash, kid string, _ string, cipherte
315314
}
316315

317316
func (s StandardCrypto) RSAPublicKeyAsJSON(kid string) (string, error) {
318-
rsaKeys, ok := s.keys[AlgorithmRSA2048]
319-
if !ok || len(rsaKeys) == 0 {
317+
k, ok := s.keysByID[kid]
318+
if !ok {
320319
return "", ErrCertNotFound
321320
}
322-
k, ok := rsaKeys[kid]
323321
if !ok {
324322
return "", ErrCertNotFound
325323
}
@@ -357,11 +355,7 @@ func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPubl
357355
}
358356
ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(pemBlock)
359357

360-
ecKeys, ok := s.keys[AlgorithmECP256R1]
361-
if !ok || len(ecKeys) == 0 {
362-
return nil, ErrNoKeys
363-
}
364-
k, ok := ecKeys[kasKID]
358+
k, ok := s.keysByID[kasKID]
365359
if !ok {
366360
return nil, ErrKeyPairInfoNotFound
367361
}

service/kas/access/provider.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,49 @@ func (p *Provider) IsReady(ctx context.Context) error {
5252
p.Logger.TraceContext(ctx, "checking readiness of kas service")
5353
return nil
5454
}
55+
56+
func (kasCfg *KASConfig) UpgradeMapToKeyring(c security.CryptoProvider) {
57+
switch {
58+
case kasCfg.ECCertID != "" && len(kasCfg.Keyring) > 0:
59+
panic("invalid kas cfg: please specify keyring or eccertid, not both")
60+
case len(kasCfg.Keyring) == 0:
61+
deprecatedOrDefault := func(kid, alg string) {
62+
if kid == "" {
63+
kid = c.FindKID(alg)
64+
}
65+
if kid == "" {
66+
// no known key for this algorithm type
67+
return
68+
}
69+
kasCfg.Keyring = append(kasCfg.Keyring, CurrentKeyFor{
70+
Algorithm: alg,
71+
KID: kid,
72+
})
73+
kasCfg.Keyring = append(kasCfg.Keyring, CurrentKeyFor{
74+
Algorithm: alg,
75+
KID: kid,
76+
Legacy: true,
77+
})
78+
}
79+
deprecatedOrDefault(kasCfg.ECCertID, security.AlgorithmECP256R1)
80+
deprecatedOrDefault(kasCfg.RSACertID, security.AlgorithmRSA2048)
81+
default:
82+
kasCfg.Keyring = append(kasCfg.Keyring, inferLegacyKeys(kasCfg.Keyring)...)
83+
}
84+
}
85+
86+
// If there exists *any* legacy keys, returns empty list.
87+
// Otherwise, create a copy with legacy=true for all values
88+
func inferLegacyKeys(keys []CurrentKeyFor) []CurrentKeyFor {
89+
for _, k := range keys {
90+
if k.Legacy {
91+
return nil
92+
}
93+
}
94+
l := make([]CurrentKeyFor, len(keys))
95+
for i, k := range keys {
96+
l[i] = k
97+
l[i].Legacy = true
98+
}
99+
return l
100+
}

service/kas/kidless_test.go renamed to service/kas/access/provider_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
package kas
1+
package access
22

33
import (
44
"testing"
55

66
"github.com/opentdf/platform/service/internal/security"
7-
"github.com/opentdf/platform/service/kas/access"
87
"github.com/stretchr/testify/assert"
98
)
109

@@ -13,14 +12,14 @@ func TestInferLegacyKeys_empty(t *testing.T) {
1312
}
1413

1514
func TestInferLegacyKeys_singles(t *testing.T) {
16-
one := []access.CurrentKeyFor{
15+
one := []CurrentKeyFor{
1716
{
1817
Algorithm: security.AlgorithmRSA2048,
1918
KID: "rsa",
2019
},
2120
}
2221

23-
oneLegacy := []access.CurrentKeyFor{
22+
oneLegacy := []CurrentKeyFor{
2423
{
2524
Algorithm: security.AlgorithmRSA2048,
2625
KID: "rsa",
@@ -34,7 +33,7 @@ func TestInferLegacyKeys_singles(t *testing.T) {
3433
}
3534

3635
func TestInferLegacyKeys_Mixed(t *testing.T) {
37-
in := []access.CurrentKeyFor{
36+
in := []CurrentKeyFor{
3837
{
3938
Algorithm: security.AlgorithmRSA2048,
4039
KID: "a",

service/kas/access/publicKey.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ func (p Provider) PublicKey(ctx context.Context, req *connect.Request[kaspb.Publ
9191
r := func(value, kid string, err error) (*connect.Response[kaspb.PublicKeyResponse], error) {
9292
if errors.Is(err, security.ErrCertNotFound) {
9393
p.Logger.ErrorContext(ctx, "no key found for", "err", err, "kid", kid, "algorithm", algorithm, "fmt", fmt)
94-
return nil, connect.NewError(connect.CodeNotFound, err)
94+
return nil, connect.NewError(connect.CodeNotFound, security.ErrCertNotFound)
9595
} else if err != nil {
9696
p.Logger.ErrorContext(ctx, "configuration error for key lookup", "err", err, "kid", kid, "algorithm", algorithm, "fmt", fmt)
97-
return nil, connect.NewError(connect.CodeInternal, err)
97+
return nil, connect.NewError(connect.CodeInternal, ErrInternal)
9898
}
9999
if req.Msg.GetV() == "1" {
100100
p.Logger.WarnContext(ctx, "hiding kid in public key response for legacy client", "kid", kid, "v", req.Msg.GetV())

0 commit comments

Comments
 (0)