Improve bytes buffer
XFY9326 committed Feb 27, 2024
1 parent 45060fb commit aa78dbc
Showing 1 changed file with 56 additions and 39 deletions.
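In short, the commit replaces manual offset arithmetic over a bytes object with a small io.BytesIO subclass (_BytesBuffer) that owns the read/write cursor. A condensed before/after sketch of the pattern, lifted from the diff below (editor's illustration, not part of the commit message):

    # Before: callers slice the bytes object and advance the cursor by hand
    flag = buffer[offset]
    offset += 1
    payload_length = int.from_bytes(buffer[offset:offset + 4], byteorder="big", signed=False)
    offset += 4

    # After: the cursor lives inside the buffer object
    bytes_buffer = _BytesBuffer(buffer)
    flag = bytes_buffer.read_uint8()
    payload_length = bytes_buffer.read_uint32()
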
95 changes: 56 additions & 39 deletions src/pyndef/ndef.py
@@ -1,4 +1,5 @@
import enum
import io
from typing import List, Tuple, Optional, Union, Iterator
from urllib.parse import urlparse

@@ -26,9 +27,6 @@ class NdefRTD(bytes, enum.Enum):
HANDOVER_SELECT = b"Hs"
ANDROID_APP = b"android.com:pkg"

def __bytes__(self) -> bytes:
return self.value


# noinspection SpellCheckingInspection
_URI_PREFIX_MAP: Tuple[str, ...] = (
@@ -84,6 +82,41 @@ def _normalize_uri_scheme(raw_uri: str) -> str:
return parsed_uri.geturl()


class _BytesBuffer(io.BytesIO):
def __init__(self, buffer: Optional[bytes] = None, byte_order: str = "big") -> None:
super().__init__(buffer)
self._byte_order: str = byte_order

def read_uint8(self) -> int:
return self.read_int(1, False)

def read_uint32(self) -> int:
return self.read_int(4, False)

def read_int(self, size: int, signed: bool) -> int:
buffer = self.read(size)
if 0 < size != len(buffer):
raise IndexError("buffer is too short")
# noinspection PyTypeChecker
return int.from_bytes(bytes=buffer, byteorder=self._byte_order, signed=signed)

def read_bytes(self, size: int) -> bytes:
buffer = self.read(size)
if len(buffer) != size:
raise IndexError("buffer is too short")
return buffer

def write_uint8(self, value: int) -> int:
return self.write_int(value, 1, False)

def write_uint32(self, value: int) -> int:
return self.write_int(value, 4, False)

def write_int(self, value: int, size: int, signed: bool) -> int:
# noinspection PyTypeChecker
return self.write(value.to_bytes(length=size, byteorder=self._byte_order, signed=signed))


# /platform/frameworks/base/nfc/java/android/nfc/NdefRecord.java
class NdefRecord:
_FLAG_MB: int = 0x80
@@ -196,16 +229,15 @@ def _ensure_sane_payload_size(size: int) -> None:
record_type: Optional[bytes] = None
record_id: Optional[bytes] = None

bytes_buffer = _BytesBuffer(buffer)
chunks: List[bytes] = []
in_chunk: bool = False
chunk_tnf: Optional[NdefTNF] = None
me: bool = False
offset: int = 0

try:
while not me:
flag = buffer[offset]
offset += 1
flag = bytes_buffer.read_uint8()

mb = flag & NdefRecord._FLAG_MB != 0
me = flag & NdefRecord._FLAG_ME != 0
@@ -227,32 +259,19 @@ def _ensure_sane_payload_size(size: int) -> None:
elif not in_chunk and tnf == NdefTNF.UNCHANGED:
raise ValueError("unexpected TNF_UNCHANGED in first chunk or not chunked record")

type_length = buffer[offset]
offset += 1
if sr:
payload_length = buffer[offset]
offset += 1
else:
payload_length = int.from_bytes(buffer[offset:offset + 4], byteorder="big", signed=False)
offset += 4
if il:
id_length = buffer[offset]
offset += 1
else:
id_length = 0
type_length = bytes_buffer.read_uint8()
payload_length = bytes_buffer.read_uint8() if sr else bytes_buffer.read_uint32()
id_length = bytes_buffer.read_uint8() if il else 0

if in_chunk and type_length != 0:
raise ValueError("expected zero-length type in non-leading chunk")

if not in_chunk:
record_type = buffer[offset:offset + type_length]
offset += type_length
record_id = buffer[offset:offset + id_length]
offset += id_length
record_type = bytes_buffer.read_bytes(type_length)
record_id = bytes_buffer.read_bytes(id_length)

_ensure_sane_payload_size(payload_length)
payload = buffer[offset:offset + payload_length]
offset += payload_length
payload = bytes_buffer.read_bytes(payload_length)

if cf and not in_chunk:
if type_length == 0 and tnf != NdefTNF.UNKNOWN:
@@ -264,9 +283,7 @@ def _ensure_sane_payload_size(size: int) -> None:
chunks.append(payload)

if not cf and in_chunk:
payload_length = 0
for chunk in chunks:
payload_length += len(chunk)
payload_length = sum(len(chunk) for chunk in chunks)
_ensure_sane_payload_size(payload_length)
payload = b''.join([chunk for chunk in chunks])
tnf = chunk_tnf
@@ -285,7 +302,7 @@ def _ensure_sane_payload_size(size: int) -> None:
except IndexError as e:
raise ValueError("expected more data") from e

if offset < len(buffer):
if bytes_buffer.read():
raise ValueError("data too long")

return tuple(records)
@@ -330,28 +347,28 @@ def _to_uri(self, in_smart_poster: bool) -> Optional[str]:
return None

def to_bytes(self, flag_mb: bool = True, flag_me: bool = True) -> bytes:
buffer = bytearray()
buffer = _BytesBuffer()

flag = (self._FLAG_MB if flag_mb else 0) | \
(self._FLAG_ME if flag_me else 0) | \
(self._FLAG_SR if self._flag_sr else 0) | \
(self._FLAG_IL if self._flag_il else 0) | self._tnf.value

buffer.append(flag)
buffer.append(len(self._record_type))
buffer.write_uint8(flag)
buffer.write_uint8(len(self._record_type))

if self._flag_sr:
buffer.append(len(self._payload))
buffer.write_uint8(len(self._payload))
else:
buffer.extend(len(self._payload).to_bytes(length=4, byteorder="big", signed=False))
buffer.write_uint32(len(self._payload))
if self._flag_il:
buffer.append(len(self._record_id))
buffer.write_uint8(len(self._record_id))

buffer.extend(self._record_type)
buffer.extend(self._record_id)
buffer.extend(self._payload)
buffer.write(self._record_type)
buffer.write(self._record_id)
buffer.write(self._payload)

return bytes(buffer)
return buffer.getvalue()

@staticmethod
def _validate_tnf(tnf: NdefTNF, record_type: bytes, record_id: bytes, payload: bytes) -> None:
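Usage note (editor's addition): a hypothetical round trip through the new helper. _BytesBuffer is module-private, so this is illustration only; the method names and big-endian default come from the class as added above.

    out = _BytesBuffer()
    out.write_uint8(0xD1)      # one unsigned byte
    out.write_uint32(5)        # four bytes, big-endian by default
    out.write(b"hello")        # plain io.BytesIO writes still work

    inp = _BytesBuffer(out.getvalue())
    assert inp.read_uint8() == 0xD1
    assert inp.read_uint32() == 5
    assert inp.read_bytes(5) == b"hello"
    assert inp.read() == b""   # exhausted; parse() relies on this for its "data too long" check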

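Error path (editor's addition, hypothetical): a short read now raises IndexError inside the helper, which parse() converts to ValueError("expected more data"), as the except clause in the diff shows.

    bb = _BytesBuffer(b"\x01")
    bb.read_uint8()        # consumes the only byte
    try:
        bb.read_uint32()   # only 0 of the 4 requested bytes remain
    except IndexError:
        pass               # parse() re-raises this as ValueError("expected more data")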