diff --git a/src/pyndef/ndef.py b/src/pyndef/ndef.py index 4254fce..4e0ab72 100644 --- a/src/pyndef/ndef.py +++ b/src/pyndef/ndef.py @@ -1,5 +1,5 @@ import enum -from typing import List, Tuple, Optional, Union +from typing import List, Tuple, Optional, Union, Iterator from urllib.parse import urlparse @@ -95,16 +95,32 @@ class NdefRecord: _MAX_PAYLOAD_SIZE: int = 10 * (1 << 20) def __init__(self, tnf: NdefTNF, record_type: Union[NdefRTD, bytes, None], record_id: Optional[bytes], payload: Optional[bytes]) -> None: - self.tnf: NdefTNF = tnf - self.record_type: bytes = bytes(record_type) if record_type is not None else bytes() - self.record_id: bytes = bytes(record_id) if record_id is not None else bytes() - self.payload: bytes = bytes(payload) if payload is not None else bytes() + self._tnf: NdefTNF = tnf + self._record_type: bytes = bytes(record_type) if record_type is not None else bytes() + self._record_id: bytes = bytes(record_id) if record_id is not None else bytes() + self._payload: bytes = bytes(payload) if payload is not None else bytes() - self._validate_tnf(tnf, self.record_type, self.record_id, self.payload) + self._validate_tnf(tnf, self._record_type, self._record_id, self._payload) + + @property + def tnf(self) -> NdefTNF: + return self._tnf + + @property + def record_type(self) -> bytes: + return self._record_type + + @property + def record_id(self) -> bytes: + return self._record_id + + @property + def payload(self) -> bytes: + return self._payload def to_known_rtd(self) -> Optional[NdefRTD]: try: - return NdefRTD(self.record_type) + return NdefRTD(self._record_type) except ValueError: return None @@ -196,11 +212,7 @@ def _ensure_sane_payload_size(size: int) -> None: cf = flag & NdefRecord._FLAG_CF != 0 sr = flag & NdefRecord._FLAG_SR != 0 il = flag & NdefRecord._FLAG_IL != 0 - raw_tnf = flag & 0x07 - try: - tnf = NdefTNF(raw_tnf) - except ValueError as e: - raise ValueError(f"unexpected tnf value: 0x{raw_tnf:02x}") from e + tnf = NdefTNF(flag & 0x07) if not mb and len(records) == 0 and not in_chunk and not ignore_mb_me: raise ValueError("expected MB flag") @@ -257,10 +269,7 @@ def _ensure_sane_payload_size(size: int) -> None: payload_length += len(chunk) _ensure_sane_payload_size(payload_length) payload = b''.join([chunk for chunk in chunks]) - if chunk_tnf is not None: - tnf = chunk_tnf - else: - raise ValueError("unknown chunk tnf") + tnf = chunk_tnf if cf: in_chunk = True @@ -268,11 +277,7 @@ def _ensure_sane_payload_size(size: int) -> None: else: in_chunk = False - if record_type is not None and record_id is not None: - NdefRecord._validate_tnf(tnf, record_type, record_id, payload) - else: - raise ValueError("unknown record type or record id") - + NdefRecord._validate_tnf(tnf, record_type, record_id, payload) records.append(NdefRecord(tnf, record_type, record_id, payload)) if ignore_mb_me: @@ -287,18 +292,18 @@ def _ensure_sane_payload_size(size: int) -> None: @property def _flag_sr(self) -> bool: - return len(self.payload) < 256 + return len(self._payload) < 256 @property def _flag_il(self) -> bool: - return True if self.tnf == NdefTNF.EMPTY else len(self.record_id) > 0 + return True if self._tnf == NdefTNF.EMPTY else len(self._record_id) > 0 def to_mime_type(self) -> Optional[str]: - if self.tnf == NdefTNF.WELL_KNOWN: - if self.record_type == NdefRTD.TEXT: + if self._tnf == NdefTNF.WELL_KNOWN: + if self._record_type == NdefRTD.TEXT: return "text/plain" - elif self.tnf == NdefTNF.MIME_MEDIA: - raw_mime_type = self.record_type.decode("ascii") + elif self._tnf == NdefTNF.MIME_MEDIA: + raw_mime_type = self._record_type.decode("ascii") return _normalize_mime_type(raw_mime_type) return None @@ -306,22 +311,22 @@ def to_uri(self) -> Optional[str]: return self._to_uri(False) def _to_uri(self, in_smart_poster: bool) -> Optional[str]: - if self.tnf == NdefTNF.WELL_KNOWN: - if self.record_type == NdefRTD.SMART_POSTER and not in_smart_poster: - for record in NdefMessage.parse(self.payload).records: + if self._tnf == NdefTNF.WELL_KNOWN: + if self._record_type == NdefRTD.SMART_POSTER and not in_smart_poster: + for record in NdefMessage.parse(self._payload): uri = record._to_uri(True) if uri is not None: return _normalize_uri_scheme(uri) - elif self.record_type == NdefRTD.URI: - if len(self.payload) >= 2: - prefix_index = self.payload[0] + elif self._record_type == NdefRTD.URI: + if len(self._payload) >= 2: + prefix_index = self._payload[0] if 0 <= prefix_index < len(_URI_PREFIX_MAP): - return _URI_PREFIX_MAP[prefix_index] + self.payload[1:].decode("utf-8") - elif self.tnf == NdefTNF.ABSOLUTE_URI: - return _normalize_uri_scheme(self.record_type.decode("utf-8")) - elif self.tnf == NdefTNF.EXTERNAL_TYPE: + return _URI_PREFIX_MAP[prefix_index] + self._payload[1:].decode("utf-8") + elif self._tnf == NdefTNF.ABSOLUTE_URI: + return _normalize_uri_scheme(self._record_type.decode("utf-8")) + elif self._tnf == NdefTNF.EXTERNAL_TYPE: if not in_smart_poster: - return "vnd.android.nfc://ext/" + self.record_type.decode("ascii") + return "vnd.android.nfc://ext/" + self._record_type.decode("ascii") return None def to_bytes(self, flag_mb: bool = True, flag_me: bool = True) -> bytes: @@ -330,21 +335,21 @@ def to_bytes(self, flag_mb: bool = True, flag_me: bool = True) -> bytes: flag = (self._FLAG_MB if flag_mb else 0) | \ (self._FLAG_ME if flag_me else 0) | \ (self._FLAG_SR if self._flag_sr else 0) | \ - (self._FLAG_IL if self._flag_il else 0) | self.tnf.value + (self._FLAG_IL if self._flag_il else 0) | self._tnf.value buffer.append(flag) - buffer.append(len(self.record_type)) + buffer.append(len(self._record_type)) if self._flag_sr: - buffer.append(len(self.payload)) + buffer.append(len(self._payload)) else: - buffer.extend(len(self.payload).to_bytes(length=4, byteorder="big", signed=False)) + buffer.extend(len(self._payload).to_bytes(length=4, byteorder="big", signed=False)) if self._flag_il: - buffer.append(len(self.record_id)) + buffer.append(len(self._record_id)) - buffer.extend(self.record_type) - buffer.extend(self.record_id) - buffer.extend(self.payload) + buffer.extend(self._record_type) + buffer.extend(self._record_id) + buffer.extend(self._payload) return bytes(buffer) @@ -359,32 +364,35 @@ def _validate_tnf(tnf: NdefTNF, record_type: bytes, record_id: bytes, payload: b elif tnf == NdefTNF.UNCHANGED: raise ValueError("unexpected TNF_UNCHANGED in first chunk or logical record") - def __len__(self) -> int: - length = 3 + len(self.record_type) + len(self.record_id) + len(self.payload) + def bytes_size(self) -> int: + length = 3 + len(self._record_type) + len(self._record_id) + len(self._payload) if not self._flag_sr: length += 3 if self._flag_il: length += 1 return length + def __bytes__(self) -> bytes: + return self.to_bytes() + def __repr__(self) -> str: return ("NdefRecord(" - f"tnf=0x{self.tnf.value:02x}, " - f"type={self.record_type}, " - f"id={self.record_id}, " - f"payload={self.payload})") + f"tnf=0x{self._tnf.value:02x}, " + f"type={self._record_type}, " + f"id={self._record_id}, " + f"payload={self._payload})") def __eq__(self, __value) -> bool: if __value is None or not isinstance(__value, NdefRecord): return super().__eq__(__value) else: - return self.tnf == __value.tnf and \ - self.record_type == __value.record_type and \ - self.record_id == __value.record_id and \ - self.payload == __value.payload + return self._tnf == __value._tnf and \ + self._record_type == __value._record_type and \ + self._record_id == __value._record_id and \ + self._payload == __value._payload def __hash__(self) -> int: - return hash((self.tnf.value, self.record_type, self.record_id, self.payload)) + return hash((self._tnf.value, self._record_type, self._record_id, self._payload)) # /platform/frameworks/base/nfc/java/android/nfc/NdefMessage.java @@ -392,35 +400,45 @@ class NdefMessage: def __init__(self, *records: NdefRecord) -> None: if len(records) == 0: raise ValueError("must have at least one record") - self.records: tuple[NdefRecord, ...] = records + self._records: tuple[NdefRecord, ...] = records + + @property + def records(self) -> tuple[NdefRecord, ...]: + return self._records @staticmethod def parse(buffer: bytes) -> 'NdefMessage': records = NdefRecord.parse(buffer, False) - if len(records) == 0: - raise ValueError("must have at least one record") - else: - return NdefMessage(*records) + return NdefMessage(*records) def to_bytes(self) -> bytes: return b''.join( [ - record.to_bytes(i == 0, i == len(self.records) - 1) - for i, record in enumerate(self.records) + record.to_bytes(i == 0, i == len(self._records) - 1) + for i, record in enumerate(self._records) ] ) + def bytes_size(self) -> int: + return sum(i.bytes_size() for i in self._records) + + def __bytes__(self) -> bytes: + return self.to_bytes() + + def __iter__(self) -> Iterator[NdefRecord]: + return iter(self._records) + def __len__(self) -> int: - return sum([len(i) for i in self.records]) + return len(self._records) def __repr__(self) -> str: - return f"NdefMessage({', '.join([repr(i) for i in self.records])})" + return f"NdefMessage({', '.join([repr(i) for i in self._records])})" def __eq__(self, __value) -> bool: if __value is None or not isinstance(__value, NdefMessage): return super().__eq__(__value) else: - return self.records == __value.records + return self._records == __value._records def __hash__(self) -> int: - return hash(self.records) + return hash(self._records) diff --git a/tests/test_ndef.py b/tests/test_ndef.py index 64959cc..9daa9c5 100644 --- a/tests/test_ndef.py +++ b/tests/test_ndef.py @@ -4,6 +4,7 @@ # /cts/tests/tests/ndef/src/android/ndef/cts/NdefTest.java +# noinspection HttpUrlsUsage class NdefTestCase(unittest.TestCase): _PAYLOAD_255 = b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + \ b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + \ @@ -273,15 +274,24 @@ def test_to_bytes(self) -> None: # 256 byte payload self.assertEqual(b"\xc5\x00\x00\x00\x01\x00" + self._PAYLOAD_256, NdefMessage(NdefRecord(NdefTNF.UNKNOWN, None, None, self._PAYLOAD_256)).to_bytes()) + self.assertEqual( + b"\xd8\x00\x00\x00", + bytes(NdefRecord(NdefTNF.EMPTY, None, None, None)) + ) + self.assertEqual( + b"\xd8\x00\x00\x00", + bytes(NdefMessage(NdefRecord(NdefTNF.EMPTY, None, None, None))) + ) + def test_get_bytes_length(self) -> None: # single short record r = NdefRecord(NdefTNF.EMPTY, None, None, None) b = b"\xd8\x00\x00\x00" - self.assertEqual(len(b), len(NdefMessage(r))) + self.assertEqual(len(b), NdefMessage(r).bytes_size()) # 3 records r = NdefRecord(NdefTNF.EMPTY, None, None, None) b = b"\x98\x00\x00\x00\x18\x00\x00\x00\x58\x00\x00\x00" - self.assertEqual(len(b), len(NdefMessage(r, r, r))) + self.assertEqual(len(b), NdefMessage(r, r, r).bytes_size()) def test_to_uri(self) -> None: # absolute uri @@ -354,12 +364,12 @@ def test_repr(self) -> None: self.assertEqual(f"NdefMessage({p1})", repr(NdefMessage(r1))) self.assertEqual(f"NdefMessage({p1}, {p2})", repr(NdefMessage(r1, r2))) - def test_len(self) -> None: + def test_iter_message(self) -> None: r1 = NdefRecord(NdefTNF.EMPTY, None, None, None) r2 = NdefRecord(NdefTNF.EXTERNAL_TYPE, b"type", b"\x01", self._PAYLOAD_256) - self.assertEqual(4, len(r1)) - self.assertEqual(268, len(r2)) - self.assertEqual(4 + 268, len(NdefMessage(r1, r2))) + self.assertEqual(1, len(list(NdefMessage(r1)))) + self.assertEqual(2, len(list(NdefMessage(r1, r2)))) + self.assertEqual(NdefMessage(r1, r2).records, tuple(iter(NdefMessage(r1, r2)))) if __name__ == "__main__":