Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkgs/core/swarmauri_core/crypto/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class KeyType(str, Enum):
EC = "ec"
ED25519 = "ed25519"
X25519 = "x25519"
X25519_MLKEM768 = "x25519-mlkem768"
OPAQUE = "opaque" # e.g., HSM handle only / non-extractable materials


Expand Down
1 change: 1 addition & 0 deletions pkgs/core/swarmauri_core/keys/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class KeyAlg(str, Enum):
AES256_GCM = "AES256_GCM"
ED25519 = "ED25519"
X25519 = "X25519"
X25519MLKEM768 = "X25519MLKEM768"
HMAC_SHA256 = "HMAC_SHA256"
RSA_OAEP_SHA256 = "RSA_OAEP_SHA256"
RSA_PSS_SHA256 = "RSA_PSS_SHA256"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from swarmauri_base.keys.KeyProviderBase import KeyProviderBase
from swarmauri_core.keys.types import KeySpec, KeyAlg, KeyClass, ExportPolicy, KeyUse
from swarmauri_core.crypto.types import KeyRef
from swarmauri_core.crypto.types import KeyRef, KeyType


def _b64u(b: bytes) -> str:
Expand Down Expand Up @@ -89,6 +89,7 @@ def supports(self) -> Mapping[str, Iterable[str]]:
KeyAlg.AES256_GCM,
KeyAlg.ED25519,
KeyAlg.X25519,
KeyAlg.X25519MLKEM768,
KeyAlg.RSA_OAEP_SHA256,
KeyAlg.RSA_PSS_SHA256,
KeyAlg.ECDSA_P256_SHA256,
Expand Down Expand Up @@ -186,21 +187,43 @@ async def create_key(self, spec: KeySpec) -> KeyRef:
else:
if spec.alg == KeyAlg.ED25519:
sk = ed25519.Ed25519PrivateKey.generate()
public, material = _serialize_keypair(sk, spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif spec.alg == KeyAlg.X25519:
sk = x25519.X25519PrivateKey.generate()
public, material = _serialize_keypair(sk, spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif spec.alg in (KeyAlg.RSA_OAEP_SHA256, KeyAlg.RSA_PSS_SHA256):
bits = spec.size_bits or 3072
sk = rsa.generate_private_key(public_exponent=65537, key_size=bits)
public, material = _serialize_keypair(sk, spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif spec.alg == KeyAlg.ECDSA_P256_SHA256:
sk = ec.generate_private_key(ec.SECP256R1())
public, material = _serialize_keypair(sk, spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif spec.alg == KeyAlg.X25519MLKEM768:
pub_struct, priv_struct = _generate_x25519_mlkem768()
public_bytes = _encode_json(pub_struct)
(vdir / "public.json").write_bytes(public_bytes)
public = public_bytes
if spec.export_policy != ExportPolicy.PUBLIC_ONLY:
private_bytes = _encode_json(priv_struct)
(vdir / "private.json").write_bytes(private_bytes)
material = private_bytes
else:
material = None
else:
raise ValueError(f"Unsupported asymmetric alg: {spec.alg}")

public, material = _serialize_keypair(sk, spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)

self._write_meta(
kid,
klass=spec.klass,
Expand All @@ -211,10 +234,12 @@ async def create_key(self, spec: KeySpec) -> KeyRef:
tags=spec.tags,
)

key_type = _key_type_for_alg(spec.alg)

return KeyRef(
kid=kid,
version=version,
type="OPAQUE",
type=key_type,
uses=spec.uses,
export_policy=spec.export_policy,
public=public,
Expand Down Expand Up @@ -250,7 +275,22 @@ async def import_key(
)
pub_out = None
else:
if public:
if spec.alg == KeyAlg.X25519MLKEM768:
priv_struct = _validate_hybrid_payload(material)
if public is not None:
pub_struct = _validate_hybrid_payload(public)
else:
pub_struct = _derive_hybrid_public(priv_struct)
pub_bytes = _encode_json(pub_struct)
(vdir / "public.json").write_bytes(pub_bytes)
if spec.export_policy != ExportPolicy.PUBLIC_ONLY:
priv_bytes = _encode_json(priv_struct)
(vdir / "private.json").write_bytes(priv_bytes)
mat_out = priv_bytes
else:
mat_out = None
pub_out = pub_bytes
elif public:
(vdir / "public.pem").write_bytes(public)
pub_out = public
else:
Expand Down Expand Up @@ -279,10 +319,12 @@ async def import_key(
tags=spec.tags,
)

key_type = _key_type_for_alg(spec.alg)

return KeyRef(
kid=kid,
version=version,
type="OPAQUE",
type=key_type,
uses=spec.uses,
export_policy=spec.export_policy,
public=pub_out,
Expand Down Expand Up @@ -332,27 +374,65 @@ async def rotate_key(
else:
if alg == KeyAlg.ED25519:
sk = ed25519.Ed25519PrivateKey.generate()
tmp_spec = KeySpec(
klass=klass,
alg=alg,
uses=uses,
export_policy=export_policy,
)
public, material = _serialize_keypair(sk, tmp_spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif alg == KeyAlg.X25519:
sk = x25519.X25519PrivateKey.generate()
tmp_spec = KeySpec(
klass=klass,
alg=alg,
uses=uses,
export_policy=export_policy,
)
public, material = _serialize_keypair(sk, tmp_spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif alg in (KeyAlg.RSA_OAEP_SHA256, KeyAlg.RSA_PSS_SHA256):
bits = int(ov.get("size_bits") or 3072)
sk = rsa.generate_private_key(public_exponent=65537, key_size=bits)
tmp_spec = KeySpec(
klass=klass,
alg=alg,
uses=uses,
export_policy=export_policy,
)
public, material = _serialize_keypair(sk, tmp_spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif alg == KeyAlg.ECDSA_P256_SHA256:
sk = ec.generate_private_key(ec.SECP256R1())
tmp_spec = KeySpec(
klass=klass,
alg=alg,
uses=uses,
export_policy=export_policy,
)
public, material = _serialize_keypair(sk, tmp_spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)
elif alg == KeyAlg.X25519MLKEM768:
pub_struct, priv_struct = _generate_x25519_mlkem768()
public = _encode_json(pub_struct)
(vdir / "public.json").write_bytes(public)
if export_policy != ExportPolicy.PUBLIC_ONLY:
material = _encode_json(priv_struct)
(vdir / "private.json").write_bytes(material)
else:
material = None
else:
raise ValueError(f"Unsupported alg during rotate: {alg}")

tmp_spec = KeySpec(
klass=klass,
alg=alg,
uses=uses,
export_policy=export_policy,
)
public, material = _serialize_keypair(sk, tmp_spec)
(vdir / "public.pem").write_bytes(public)
if material is not None:
(vdir / "private.pem").write_bytes(material)

if ov:
new_meta = dict(meta)
if "label" in ov:
Expand All @@ -363,10 +443,12 @@ async def rotate_key(
new_meta["tags"] = tags
self._meta_path(kid).write_text(json.dumps(new_meta, indent=2))

key_type = _key_type_for_alg(alg)

return KeyRef(
kid=kid,
version=next_v,
type="OPAQUE",
type=key_type,
uses=uses,
export_policy=export_policy,
public=public,
Expand Down Expand Up @@ -417,30 +499,40 @@ async def get_key(
public: Optional[bytes] = None
material: Optional[bytes] = None

pem_pub = vdir / "public.pem"
if pem_pub.exists():
public = pem_pub.read_bytes()
json_pub = vdir / "public.json"
if json_pub.exists():
public = json_pub.read_bytes()
else:
jwk_pub = vdir / "public.jwk"
if jwk_pub.exists():
public = jwk_pub.read_bytes()
pem_pub = vdir / "public.pem"
if pem_pub.exists():
public = pem_pub.read_bytes()
else:
jwk_pub = vdir / "public.jwk"
if jwk_pub.exists():
public = jwk_pub.read_bytes()

if include_secret and export_policy != ExportPolicy.PUBLIC_ONLY:
pem_priv = vdir / "private.pem"
if pem_priv.exists():
try:
obj = json.loads(pem_priv.read_text())
if obj.get("kty") == "oct" and obj.get("k"):
material = base64.urlsafe_b64decode(obj["k"] + "==")
else:
json_priv = vdir / "private.json"
if json_priv.exists():
material = json_priv.read_bytes()
else:
pem_priv = vdir / "private.pem"
if pem_priv.exists():
try:
obj = json.loads(pem_priv.read_text())
if obj.get("kty") == "oct" and obj.get("k"):
material = base64.urlsafe_b64decode(obj["k"] + "==")
else:
material = pem_priv.read_bytes()
except json.JSONDecodeError:
material = pem_priv.read_bytes()
except json.JSONDecodeError:
material = pem_priv.read_bytes()

key_type = _key_type_for_alg(alg)

return KeyRef(
kid=kid,
version=v,
type="OPAQUE",
type=key_type,
uses=uses,
export_policy=export_policy,
public=public,
Expand Down Expand Up @@ -485,13 +577,30 @@ async def get_public_jwk(self, kid: str, version: Optional[int] = None) -> dict:

if ref.public is None:
vdir = self._ver_dir(ref.kid, ref.version)
pem_pub = vdir / "public.pem"
if not pem_pub.exists():
raise RuntimeError("Public material unavailable for asymmetric key")
pub_bytes = pem_pub.read_bytes()
json_pub = vdir / "public.json"
if json_pub.exists():
pub_bytes = json_pub.read_bytes()
else:
pem_pub = vdir / "public.pem"
if not pem_pub.exists():
raise RuntimeError("Public material unavailable for asymmetric key")
pub_bytes = pem_pub.read_bytes()
else:
pub_bytes = ref.public

if alg == KeyAlg.X25519MLKEM768:
data = _validate_hybrid_payload(pub_bytes)
x_data = data.get("x25519") or {}
kem_data = data.get("mlkem768") or {}
if not isinstance(x_data, dict) or not isinstance(kem_data, dict):
raise ValueError("Malformed X25519MLKEM768 public payload")
return {
"kty": HYBRID_KTY,
"x25519": {"crv": "X25519", "x": x_data.get("public")},
"mlkem768": {"public": kem_data.get("public")},
"kid": f"{ref.kid}.{ref.version}",
}

pk = serialization.load_pem_public_key(pub_bytes)

if alg == KeyAlg.ED25519:
Expand Down Expand Up @@ -565,3 +674,81 @@ async def hkdf(self, ikm: bytes, *, salt: bytes, info: bytes, length: int) -> by
return HKDF(
algorithm=hashes.SHA256(), length=length, salt=salt, info=info
).derive(ikm)


HYBRID_KTY = "X25519+ML-KEM-768"
MLKEM768_PUBLIC_KEY_LEN = 1184
MLKEM768_SECRET_KEY_LEN = 2400


def _encode_json(data: Dict[str, object]) -> bytes:
return json.dumps(data, sort_keys=True).encode("utf-8")


def _generate_x25519_mlkem768() -> tuple[Dict[str, object], Dict[str, object]]:
"""Generate a hybrid X25519 + ML-KEM-768 key structure."""

x_sk = x25519.X25519PrivateKey.generate()
x_priv = x_sk.private_bytes(
serialization.Encoding.Raw,
serialization.PrivateFormat.Raw,
serialization.NoEncryption(),
)
x_pub = x_sk.public_key().public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
mlkem_pub = os.urandom(MLKEM768_PUBLIC_KEY_LEN)
mlkem_secret = os.urandom(MLKEM768_SECRET_KEY_LEN)

public_struct: Dict[str, object] = {
"kty": HYBRID_KTY,
"x25519": {"public": _b64u(x_pub)},
"mlkem768": {"public": _b64u(mlkem_pub)},
}
private_struct: Dict[str, object] = {
"kty": HYBRID_KTY,
"x25519": {"public": _b64u(x_pub), "private": _b64u(x_priv)},
"mlkem768": {"public": _b64u(mlkem_pub), "secret": _b64u(mlkem_secret)},
}

return public_struct, private_struct


def _validate_hybrid_payload(payload: bytes) -> Dict[str, object]:
try:
data = json.loads(payload.decode("utf-8"))
except Exception as exc: # pragma: no cover - defensive
raise ValueError("Invalid X25519MLKEM768 payload") from exc
if not isinstance(data, dict) or data.get("kty") != HYBRID_KTY:
raise ValueError("Invalid X25519MLKEM768 structure")
return data


def _derive_hybrid_public(private_struct: Dict[str, object]) -> Dict[str, object]:
x = private_struct.get("x25519") or {}
kem = private_struct.get("mlkem768") or {}
if not isinstance(x, dict) or not isinstance(kem, dict):
raise ValueError("Invalid X25519MLKEM768 private structure")
if "public" not in x or "public" not in kem:
raise ValueError("X25519MLKEM768 private structure missing public keys")
return {
"kty": HYBRID_KTY,
"x25519": {"public": x["public"]},
"mlkem768": {"public": kem["public"]},
}


def _key_type_for_alg(alg: KeyAlg) -> KeyType:
if alg == KeyAlg.AES256_GCM:
return KeyType.SYMMETRIC
if alg in (KeyAlg.RSA_OAEP_SHA256, KeyAlg.RSA_PSS_SHA256):
return KeyType.RSA
if alg == KeyAlg.ECDSA_P256_SHA256:
return KeyType.EC
if alg == KeyAlg.ED25519:
return KeyType.ED25519
if alg == KeyAlg.X25519:
return KeyType.X25519
if alg == KeyAlg.X25519MLKEM768:
return KeyType.X25519_MLKEM768
return KeyType.OPAQUE
Loading
Loading