diff --git a/src/pyndef/ndef.py b/src/pyndef/ndef.py index 7f73ff8..f899854 100644 --- a/src/pyndef/ndef.py +++ b/src/pyndef/ndef.py @@ -1,4 +1,5 @@ import enum +import io from typing import List, Tuple, Optional, Union, Iterator from urllib.parse import urlparse @@ -26,9 +27,6 @@ class NdefRTD(bytes, enum.Enum): HANDOVER_SELECT = b"Hs" ANDROID_APP = b"android.com:pkg" - def __bytes__(self) -> bytes: - return self.value - # noinspection SpellCheckingInspection _URI_PREFIX_MAP: Tuple[str, ...] = ( @@ -84,6 +82,41 @@ def _normalize_uri_scheme(raw_uri: str) -> str: return parsed_uri.geturl() +class _BytesBuffer(io.BytesIO): + def __init__(self, buffer: Optional[bytes] = None, byte_order: str = "big") -> None: + super().__init__(buffer) + self._byte_order: str = byte_order + + def read_uint8(self) -> int: + return self.read_int(1, False) + + def read_uint32(self) -> int: + return self.read_int(4, False) + + def read_int(self, size: int, signed: bool) -> int: + buffer = self.read(size) + if 0 < size != len(buffer): + raise IndexError("buffer is too short") + # noinspection PyTypeChecker + return int.from_bytes(bytes=buffer, byteorder=self._byte_order, signed=signed) + + def read_bytes(self, size: int) -> bytes: + buffer = self.read(size) + if len(buffer) != size: + raise IndexError("buffer is too short") + return buffer + + def write_uint8(self, value: int) -> int: + return self.write_int(value, 1, False) + + def write_uint32(self, value: int) -> int: + return self.write_int(value, 4, False) + + def write_int(self, value: int, size: int, signed: bool) -> int: + # noinspection PyTypeChecker + return self.write(value.to_bytes(length=size, byteorder=self._byte_order, signed=signed)) + + # /platform/frameworks/base/nfc/java/android/nfc/NdefRecord.java class NdefRecord: _FLAG_MB: int = 0x80 @@ -196,16 +229,15 @@ def _ensure_sane_payload_size(size: int) -> None: record_type: Optional[bytes] = None record_id: Optional[bytes] = None + bytes_buffer = _BytesBuffer(buffer) chunks: List[bytes] = [] in_chunk: bool = False chunk_tnf: Optional[NdefTNF] = None me: bool = False - offset: int = 0 try: while not me: - flag = buffer[offset] - offset += 1 + flag = bytes_buffer.read_uint8() mb = flag & NdefRecord._FLAG_MB != 0 me = flag & NdefRecord._FLAG_ME != 0 @@ -227,32 +259,19 @@ def _ensure_sane_payload_size(size: int) -> None: elif not in_chunk and tnf == NdefTNF.UNCHANGED: raise ValueError("unexpected TNF_UNCHANGED in first chunk or not chunked record") - type_length = buffer[offset] - offset += 1 - if sr: - payload_length = buffer[offset] - offset += 1 - else: - payload_length = int.from_bytes(buffer[offset:offset + 4], byteorder="big", signed=False) - offset += 4 - if il: - id_length = buffer[offset] - offset += 1 - else: - id_length = 0 + type_length = bytes_buffer.read_uint8() + payload_length = bytes_buffer.read_uint8() if sr else bytes_buffer.read_uint32() + id_length = bytes_buffer.read_uint8() if il else 0 if in_chunk and type_length != 0: raise ValueError("expected zero-length type in non-leading chunk") if not in_chunk: - record_type = buffer[offset:offset + type_length] - offset += type_length - record_id = buffer[offset:offset + id_length] - offset += id_length + record_type = bytes_buffer.read_bytes(type_length) + record_id = bytes_buffer.read_bytes(id_length) _ensure_sane_payload_size(payload_length) - payload = buffer[offset:offset + payload_length] - offset += payload_length + payload = bytes_buffer.read_bytes(payload_length) if cf and not in_chunk: if type_length == 0 and tnf != NdefTNF.UNKNOWN: @@ -264,9 +283,7 @@ def _ensure_sane_payload_size(size: int) -> None: chunks.append(payload) if not cf and in_chunk: - payload_length = 0 - for chunk in chunks: - payload_length += len(chunk) + payload_length = sum(len(chunk) for chunk in chunks) _ensure_sane_payload_size(payload_length) payload = b''.join([chunk for chunk in chunks]) tnf = chunk_tnf @@ -285,7 +302,7 @@ def _ensure_sane_payload_size(size: int) -> None: except IndexError as e: raise ValueError("expected more data") from e - if offset < len(buffer): + if bytes_buffer.read(): raise ValueError("data too long") return tuple(records) @@ -330,28 +347,28 @@ def _to_uri(self, in_smart_poster: bool) -> Optional[str]: return None def to_bytes(self, flag_mb: bool = True, flag_me: bool = True) -> bytes: - buffer = bytearray() + buffer = _BytesBuffer() flag = (self._FLAG_MB if flag_mb else 0) | \ (self._FLAG_ME if flag_me else 0) | \ (self._FLAG_SR if self._flag_sr else 0) | \ (self._FLAG_IL if self._flag_il else 0) | self._tnf.value - buffer.append(flag) - buffer.append(len(self._record_type)) + buffer.write_uint8(flag) + buffer.write_uint8(len(self._record_type)) if self._flag_sr: - buffer.append(len(self._payload)) + buffer.write_uint8(len(self._payload)) else: - buffer.extend(len(self._payload).to_bytes(length=4, byteorder="big", signed=False)) + buffer.write_uint32(len(self._payload)) if self._flag_il: - buffer.append(len(self._record_id)) + buffer.write_uint8(len(self._record_id)) - buffer.extend(self._record_type) - buffer.extend(self._record_id) - buffer.extend(self._payload) + buffer.write(self._record_type) + buffer.write(self._record_id) + buffer.write(self._payload) - return bytes(buffer) + return buffer.getvalue() @staticmethod def _validate_tnf(tnf: NdefTNF, record_type: bytes, record_id: bytes, payload: bytes) -> None: