Skip to content

Commit 56b2881

Browse files
committed
Add union and intersection for KeyJar
1 parent 536ff4c commit 56b2881

File tree

2 files changed

+107
-8
lines changed

2 files changed

+107
-8
lines changed

src/cryptojwt/key_jar.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import contextlib
22
import json
33
import logging
4+
from collections import defaultdict
45
from typing import List, Optional
56

67
from requests import request
@@ -621,6 +622,53 @@ def copy(self):
621622
def __len__(self):
622623
return len(self._issuers)
623624

625+
def union(self, *args) -> "KeyJar":
626+
"""Return new KeyJar which is the union of self and all args"""
627+
628+
issuer_keys: dict[str, set[JWK]] = defaultdict(set)
629+
630+
for _id, _issuer in self._issuers.items():
631+
issuer_keys[_id] |= set(_issuer.all_keys())
632+
633+
for key_jar in args:
634+
for _id, _issuer in key_jar.items():
635+
issuer_keys[_id] |= set(_issuer.all_keys())
636+
637+
res = KeyJar()
638+
for _id, keys in issuer_keys.items():
639+
kb = KeyBundle()
640+
kb.set(keys)
641+
res.add_kb(_id, kb)
642+
643+
return res
644+
645+
def __or__(self, other) -> "KeyJar":
646+
return self.union(other)
647+
648+
def intersection(self, *args) -> "KeyJar":
649+
"""Return new KeyJar which is the intersection of self and all args"""
650+
651+
issuer_keys: dict[str, set[JWK]] = defaultdict(set)
652+
653+
for _id, _issuer in self._issuers.items():
654+
issuer_keys[_id] |= set(_issuer.all_keys())
655+
656+
for key_jar in args:
657+
for _id, _issuer in key_jar.items():
658+
issuer_keys[_id] &= set(_issuer.all_keys())
659+
660+
res = KeyJar()
661+
for _id, keys in issuer_keys.items():
662+
if keys:
663+
kb = KeyBundle()
664+
kb.set(keys)
665+
res.add_kb(_id, kb)
666+
667+
return res
668+
669+
def __and__(self, other) -> "KeyJar":
670+
return self.intersection(other)
671+
624672
def _dump_issuers(
625673
self,
626674
exclude_issuers: Optional[List[str]] = None,

tests/test_04_key_jar.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from cryptojwt.exception import IssuerNotFound, JWKESTException
1010
from cryptojwt.jwe.jwenc import JWEnc
11+
from cryptojwt.jwk.okp import new_okp_key
1112
from cryptojwt.jws.jws import JWS, factory
1213
from cryptojwt.key_bundle import KeyBundle, keybundle_from_local_file, rsa_init
1314
from cryptojwt.key_jar import KeyJar, build_keyjar, init_key_jar, rotate_keys
@@ -1020,13 +1021,63 @@ def test_contains():
10201021
assert "David" not in kj
10211022

10221023

1023-
def test_similar():
1024-
ISSUER = "xyzzy"
1024+
def test_union():
1025+
kj1 = KeyJar()
1026+
kj2 = KeyJar()
1027+
kj3 = KeyJar()
10251028

1026-
kj = KeyJar()
1027-
kb = KeyBundle(JWK2)
1028-
kj.add_kb(issuer_id=ISSUER, kb=kb)
1029+
alice_keys1 = [
1030+
new_okp_key(crv="Ed25519", kid="key1a").serialize(private=True),
1031+
new_okp_key(crv="Ed25519", kid="key1b").serialize(private=True),
1032+
]
1033+
1034+
alice_keys2 = [
1035+
new_okp_key(crv="Ed25519", kid="key2a").serialize(private=True),
1036+
new_okp_key(crv="Ed25519", kid="key2b").serialize(private=True),
1037+
]
1038+
1039+
bob_keys1 = [
1040+
new_okp_key(crv="Ed25519", kid="key1a").serialize(private=True),
1041+
new_okp_key(crv="Ed25519", kid="key1b").serialize(private=True),
1042+
new_okp_key(crv="Ed25519", kid="key2a").serialize(private=True),
1043+
new_okp_key(crv="Ed25519", kid="key2b").serialize(private=True),
1044+
]
1045+
1046+
kj1.add_kb("Alice", KeyBundle(keys=alice_keys1))
1047+
kj2.add_kb("Alice", KeyBundle(keys=alice_keys2))
1048+
kj3.add_kb("Bob", KeyBundle(keys=bob_keys1))
1049+
1050+
kj = kj1 | kj2 | kj3
1051+
1052+
assert len(kj["Alice"].all_keys()) == 4
1053+
assert len(kj["Bob"].all_keys()) == 4
1054+
1055+
1056+
def test_intersection():
1057+
kj1 = KeyJar()
1058+
kj2 = KeyJar()
1059+
1060+
alice_keys = [
1061+
new_okp_key(crv="Ed25519", kid="key1").serialize(private=True),
1062+
new_okp_key(crv="Ed25519", kid="key2").serialize(private=True),
1063+
new_okp_key(crv="Ed25519", kid="key3").serialize(private=True),
1064+
new_okp_key(crv="Ed25519", kid="key4").serialize(private=True),
1065+
]
1066+
1067+
bob_keys = [
1068+
new_okp_key(crv="Ed25519", kid="key1").serialize(private=True),
1069+
new_okp_key(crv="Ed25519", kid="key2").serialize(private=True),
1070+
new_okp_key(crv="Ed25519", kid="key3").serialize(private=True),
1071+
new_okp_key(crv="Ed25519", kid="key4").serialize(private=True),
1072+
]
1073+
1074+
kj1.add_kb("Alice", KeyBundle(keys=[alice_keys[0], alice_keys[1]]))
1075+
kj1.add_kb("Bob", KeyBundle(keys=[bob_keys[0], bob_keys[1]]))
1076+
1077+
kj2.add_kb("Alice", KeyBundle(keys=[alice_keys[1], alice_keys[2], alice_keys[3]]))
1078+
kj2.add_kb("Bob", KeyBundle(keys=[bob_keys[0], bob_keys[1]]))
1079+
1080+
kj = kj1 & kj2
10291081

1030-
keys1 = kj.get_issuer_keys(ISSUER)
1031-
keys2 = kj[ISSUER].all_keys()
1032-
assert keys1 == keys2
1082+
assert len(kj["Alice"].all_keys()) == 1
1083+
assert len(kj["Bob"].all_keys()) == 2

0 commit comments

Comments
 (0)