Skip to content
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ assert b"Hello world!" == recipient.decode(encoded, mac_key)
## You can get decoded protected/unprotected headers with the payload as follows:
# protected, unprotected, payload = recipient.decode_with_headers(encoded, mac_key)
# assert b"Hello world!" == payload

## Note that to pass header parameters with tstr labels, or tstr values, and avoid
# clashes with short-string names such as "alg" or value encoding to bstr, you can
# resolve the headers yourself, and pass a cwt.utils.ResolvedHeader({...}).
#
# For example:
# protected=cwt.utils.ResolvedHeader({
# "string label": "value"
# })
```

**CWT API**
Expand Down
4 changes: 2 additions & 2 deletions cwt/cbor_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Union

from cbor2 import dumps, loads

Expand All @@ -12,7 +12,7 @@ def _dumps(self, obj: Any) -> bytes:
except Exception as err:
raise EncodeError("Failed to encode.") from err

def _loads(self, s: bytes) -> Dict[int, Any]:
def _loads(self, s: bytes) -> Dict[Union[str, int], Any]:
try:
return loads(s)
except Exception as err:
Expand Down
46 changes: 24 additions & 22 deletions cwt/cose.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .recipient_interface import RecipientInterface
from .recipients import Recipients
from .signer import Signer
from .utils import sort_keys_for_deterministic_encoding, to_cose_header
from .utils import ResolvedHeader, sort_keys_for_deterministic_encoding, to_cose_header


class COSE(CBORProcessor):
Expand Down Expand Up @@ -132,8 +132,8 @@ def encode(
self,
payload: bytes,
key: Optional[COSEKeyInterface] = None,
protected: Optional[dict] = None,
unprotected: Optional[dict] = None,
protected: Optional[Union[dict, ResolvedHeader]] = None,
unprotected: Optional[Union[dict, ResolvedHeader]] = None,
recipients: List[RecipientInterface] = [],
signers: List[Signer] = [],
external_aad: bytes = b"",
Expand All @@ -146,8 +146,8 @@ def encode(
Args:
payload (bytes): A content to be MACed, signed or encrypted.
key (Optional[COSEKeyInterface]): A content encryption key as COSEKey.
protected (Optional[dict]): Parameters that are to be cryptographically protected.
unprotected (Optional[dict]): Parameters that are not cryptographically protected.
protected (Optional[Union[dict, ResolvedHeader]]): Parameters that are to be cryptographically protected.
unprotected (Optional[Union[dict, ResolvedHeader]]): Parameters that are not cryptographically protected.
recipients (List[RecipientInterface]): A list of recipient information structures.
signers (List[Signer]): A list of signer information objects for
multiple signer cases.
Expand Down Expand Up @@ -351,7 +351,7 @@ def decode_with_headers(
external_aad: bytes = b"",
detached_payload: Optional[bytes] = None,
enable_non_aead: bool = False,
) -> Tuple[Dict[int, Any], Dict[int, Any], bytes]:
) -> Tuple[Dict[Union[str, int], Any], Dict[Union[str, int], Any], bytes]:
"""
Verifies and decodes COSE data, and returns protected headers, unprotected headers and payload.

Expand All @@ -371,7 +371,7 @@ def decode_with_headers(
Since non-AEAD ciphers DO NOT provide neither authentication nor integrity
of decrypted message, make sure to validate them outside of this library.
Returns:
Tuple[Dict[int, Any], Dict[int, Any], bytes]: A dictionary data of decoded protected headers, and a dictionary data of unprotected headers, and a byte string of decoded payload.
Tuple[Dict[Union[str, int], Any], Dict[Union[str, int], Any], bytes]: A dictionary data of decoded protected headers, and a dictionary data of unprotected headers, and a byte string of decoded payload.
Raises:
ValueError: Invalid arguments.
DecodeError: Failed to decode data.
Expand Down Expand Up @@ -582,10 +582,10 @@ def decode_with_headers(
def _encode_headers(
self,
key: Optional[COSEKeyInterface],
protected: Optional[dict],
unprotected: Optional[dict],
protected: Optional[Union[dict, ResolvedHeader]],
unprotected: Optional[Union[dict, ResolvedHeader]],
enable_non_aead: bool,
) -> Tuple[Dict[int, Any], Dict[int, Any]]:
) -> Tuple[Dict[Union[str, int], Any], Dict[Union[str, int], Any]]:
p = to_cose_header(protected)
u = to_cose_header(unprotected)
if key is not None:
Expand All @@ -612,30 +612,32 @@ def _encode_headers(
raise ValueError("protected header MUST be zero-length")
return p, u

def _decode_headers(self, protected: Any, unprotected: Any) -> Tuple[Dict[int, Any], Dict[int, Any]]:
p: Union[Dict[int, Any], bytes]
def _decode_headers(
self, protected: Any, unprotected: Any
) -> Tuple[Dict[Union[str, int], Any], Dict[Union[str, int], Any]]:
p: Union[Dict[Union[str, int], Any], bytes]
p = self._loads(protected) if protected else {}
if isinstance(p, bytes):
if len(p) > 0:
raise ValueError("Invalid protected header.")
p = {}
u: Dict[int, Any] = unprotected
u: Dict[Union[str, int], Any] = unprotected
if not isinstance(u, dict):
raise ValueError("unprotected header should be dict.")
return p, u

def _validate_cose_message(
self,
key: Optional[COSEKeyInterface],
p: Dict[int, Any],
u: Dict[int, Any],
p: Dict[Union[str, int], Any],
u: Dict[Union[str, int], Any],
recipients: List[RecipientInterface],
signers: List[Signer],
) -> int:
if len(recipients) > 0 and len(signers) > 0:
raise ValueError("Both recipients and signers are specified.")

h: Dict[int, Any] = {}
h: Dict[Union[str, int], Any] = {}
iv_count: int = 0
for k, v in p.items():
if k == 2: # crit
Expand Down Expand Up @@ -745,8 +747,8 @@ def _encode_and_encrypt(
self,
payload: bytes,
key: Optional[COSEKeyInterface],
p: Dict[int, Any],
u: Dict[int, Any],
p: Dict[Union[str, int], Any],
u: Dict[Union[str, int], Any],
recipients: List[RecipientInterface],
external_aad: bytes,
out: str,
Expand Down Expand Up @@ -806,8 +808,8 @@ def _encode_and_mac(
self,
payload: bytes,
key: Optional[COSEKeyInterface],
p: Dict[int, Any],
u: Dict[int, Any],
p: Dict[Union[str, int], Any],
u: Dict[Union[str, int], Any],
recipients: List[RecipientInterface],
external_aad: bytes,
out: str,
Expand Down Expand Up @@ -849,8 +851,8 @@ def _encode_and_sign(
self,
payload: bytes,
key: Optional[COSEKeyInterface],
p: Dict[int, Any],
u: Dict[int, Any],
p: Dict[Union[str, int], Any],
u: Dict[Union[str, int], Any],
signers: List[Signer],
external_aad: bytes,
out: str,
Expand Down
6 changes: 3 additions & 3 deletions cwt/cose_message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

from cbor2 import CBORTag, loads

Expand Down Expand Up @@ -132,14 +132,14 @@ def type(self) -> COSETypes:
return self._type

@property
def protected(self) -> Dict[int, Any]:
def protected(self) -> Dict[Union[str, int], Any]:
"""
The protected headers as a CBOR object.
"""
return self._loads(self._protected)

@property
def unprotected(self) -> Dict[int, Any]:
def unprotected(self) -> Dict[Union[str, int], Any]:
"""
The unprotected headers as a CBOR object.
"""
Expand Down
10 changes: 5 additions & 5 deletions cwt/cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def decode(
data: bytes,
keys: Union[COSEKeyInterface, List[COSEKeyInterface]],
no_verify: bool = False,
) -> Union[Dict[int, Any], bytes]:
) -> Union[Dict[Union[str, int], Any], bytes]:
"""
Verifies and decodes CWT.

Expand All @@ -333,11 +333,11 @@ def decode(
DecodeError: Failed to decode the CWT.
VerifyError: Failed to verify the CWT.
"""
cwt: Union[bytes, CBORTag, Dict[int, Any]] = self._loads(data)
cwt: Union[bytes, CBORTag, Dict[Union[str, int], Any]] = self._loads(data)
if isinstance(cwt, CBORTag) and cwt.tag == CWT.CBOR_TAG:
cwt = cwt.value
keys = [keys] if isinstance(keys, COSEKeyInterface) else keys
p: Dict[int, Any] = {}
p: Dict[Union[str, int], Any] = {}
while isinstance(cwt, CBORTag):
p, u, cwt = self._cose.decode_with_headers(cwt, keys)
cwt = self._loads(cwt)
Expand Down Expand Up @@ -399,7 +399,7 @@ def _validate(self, claims: Union[Dict[int, Any], bytes]):
Claims.validate(claims)
return

def _verify(self, claims: Union[Dict[int, Any], bytes], protected: Dict[int, Any] = {}):
def _verify(self, claims: Union[Dict[Union[str, int], Any], bytes], protected: Dict[Union[str, int], Any] = {}):
if not isinstance(claims, dict):
raise DecodeError("Failed to decode.")

Expand Down Expand Up @@ -484,7 +484,7 @@ def decode(
data: bytes,
keys: Union[COSEKeyInterface, List[COSEKeyInterface]],
no_verify: bool = False,
) -> Union[Dict[int, Any], bytes]:
) -> Union[Dict[Union[str, int], Any], bytes]:
return _cwt.decode(data, keys, no_verify)


Expand Down
12 changes: 6 additions & 6 deletions cwt/recipient.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .recipient_algs.ecdh_direct_hkdf import ECDH_DirectHKDF
from .recipient_algs.hpke import HPKE
from .recipient_interface import RecipientInterface
from .utils import to_cose_header, to_recipient_context
from .utils import ResolvedHeader, to_cose_header, to_recipient_context


class Recipient:
Expand All @@ -30,8 +30,8 @@ class Recipient:
@classmethod
def new(
cls,
protected: dict = {},
unprotected: dict = {},
protected: Union[dict, ResolvedHeader] = {},
unprotected: Union[dict, ResolvedHeader] = {},
ciphertext: bytes = b"",
recipients: List[Any] = [],
sender_key: Optional[COSEKeyInterface] = None,
Expand All @@ -42,8 +42,8 @@ def new(
Creates a recipient from a CBOR-like dictionary with numeric keys.

Args:
protected (dict): Parameters that are to be cryptographically protected.
unprotected (dict): Parameters that are not cryptographically protected.
protected (Union[dict, ResolvedHeader]): Parameters that are to be cryptographically protected.
unprotected (Union[dict, ResolvedHeader]): Parameters that are not cryptographically protected.
ciphertext (List[Any]): A cipher text.
sender_key (Optional[COSEKeyInterface]): A sender private key as COSEKey.
recipient_key (Optional[COSEKeyInterface]): A recipient public key as COSEKey.
Expand Down Expand Up @@ -74,7 +74,7 @@ def new(
if alg == -6:
return DirectKey(p, u)
if alg in COSE_ALGORITHMS_KEY_WRAP.values():
if len(protected) > 0:
if len(p) > 0:
raise ValueError("The protected header must be a zero-length string in key wrap mode with an AE algorithm.")
if not sender_key:
sender_key = COSEKey.from_symmetric_key(alg=alg)
Expand Down
2 changes: 1 addition & 1 deletion cwt/recipient_algs/aes_key_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AESKeyWrap(RecipientInterface):

def __init__(
self,
unprotected: Dict[int, Any],
unprotected: Dict[Union[str, int], Any],
ciphertext: bytes = b"",
recipients: List[Any] = [],
sender_key: Optional[COSEKeyInterface] = None,
Expand Down
6 changes: 3 additions & 3 deletions cwt/recipient_algs/direct.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from ..recipient_interface import RecipientInterface


class Direct(RecipientInterface):
def __init__(
self,
protected: Dict[int, Any],
unprotected: Dict[int, Any],
protected: Dict[Union[str, int], Any],
unprotected: Dict[Union[str, int], Any],
ciphertext: bytes = b"",
recipients: List[Any] = [],
):
Expand Down
4 changes: 2 additions & 2 deletions cwt/recipient_algs/direct_hkdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class DirectHKDF(Direct):

def __init__(
self,
protected: Dict[int, Any] = {},
unprotected: Dict[int, Any] = {},
protected: Dict[Union[str, int], Any] = {},
unprotected: Dict[Union[str, int], Any] = {},
context: List[Any] = [],
):
super().__init__(protected, unprotected, b"", [])
Expand Down
2 changes: 1 addition & 1 deletion cwt/recipient_algs/direct_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class DirectKey(Direct):
def __init__(self, protected: Dict[int, Any] = {}, unprotected: Dict[int, Any] = {}):
def __init__(self, protected: Dict[Union[str, int], Any] = {}, unprotected: Dict[Union[str, int], Any] = {}):
super().__init__(protected, unprotected, b"", [])

if self._alg != -6:
Expand Down
4 changes: 2 additions & 2 deletions cwt/recipient_algs/ecdh_aes_key_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class ECDH_AESKeyWrap(RecipientInterface):

def __init__(
self,
protected: Dict[int, Any],
unprotected: Dict[int, Any],
protected: Dict[Union[str, int], Any],
unprotected: Dict[Union[str, int], Any],
ciphertext: bytes = b"",
recipients: List[Any] = [],
sender_key: Optional[COSEKeyInterface] = None,
Expand Down
4 changes: 2 additions & 2 deletions cwt/recipient_algs/ecdh_direct_hkdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class ECDH_DirectHKDF(Direct):

def __init__(
self,
protected: Dict[int, Any],
unprotected: Dict[int, Any],
protected: Dict[Union[str, int], Any],
unprotected: Dict[Union[str, int], Any],
ciphertext: bytes = b"",
recipients: List[Any] = [],
sender_key: Optional[COSEKeyInterface] = None,
Expand Down
4 changes: 2 additions & 2 deletions cwt/recipient_algs/hpke.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def to_hpke_ciphersuites(alg: int) -> Tuple[int, int, int]:
class HPKE(RecipientInterface):
def __init__(
self,
protected: Dict[int, Any],
unprotected: Dict[int, Any],
protected: Dict[Union[str, int], Any],
unprotected: Dict[Union[str, int], Any],
ciphertext: bytes = b"",
recipients: List[Any] = [],
recipient_key: Optional[COSEKeyInterface] = None,
Expand Down
8 changes: 4 additions & 4 deletions cwt/recipient_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class RecipientInterface(CBORProcessor):

def __init__(
self,
protected: Optional[Dict[int, Any]] = None,
unprotected: Optional[Dict[int, Any]] = None,
protected: Optional[Dict[Union[str, int], Any]] = None,
unprotected: Optional[Dict[Union[str, int], Any]] = None,
ciphertext: bytes = b"",
recipients: List[Any] = [],
key_ops: List[int] = [],
Expand Down Expand Up @@ -106,7 +106,7 @@ def alg(self) -> int:
return self._alg

@property
def protected(self) -> Dict[int, Any]:
def protected(self) -> Dict[Union[str, int], Any]:
"""
The parameters that are to be cryptographically protected.
"""
Expand All @@ -122,7 +122,7 @@ def b_protected(self) -> bytes:
return self._b_protected

@property
def unprotected(self) -> Dict[int, Any]:
def unprotected(self) -> Dict[Union[str, int], Any]:
"""
The parameters that are not cryptographically protected.
"""
Expand Down
Loading
Loading