Skip to content

Commit

Permalink
Improve code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
XFY9326 committed Feb 26, 2024
1 parent a8ac9ad commit a9605a4
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 74 deletions.
154 changes: 86 additions & 68 deletions src/pyndef/ndef.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import enum
from typing import List, Tuple, Optional, Union
from typing import List, Tuple, Optional, Union, Iterator
from urllib.parse import urlparse


Expand Down Expand Up @@ -95,16 +95,32 @@ class NdefRecord:
_MAX_PAYLOAD_SIZE: int = 10 * (1 << 20)

def __init__(self, tnf: NdefTNF, record_type: Union[NdefRTD, bytes, None], record_id: Optional[bytes], payload: Optional[bytes]) -> None:
self.tnf: NdefTNF = tnf
self.record_type: bytes = bytes(record_type) if record_type is not None else bytes()
self.record_id: bytes = bytes(record_id) if record_id is not None else bytes()
self.payload: bytes = bytes(payload) if payload is not None else bytes()
self._tnf: NdefTNF = tnf
self._record_type: bytes = bytes(record_type) if record_type is not None else bytes()
self._record_id: bytes = bytes(record_id) if record_id is not None else bytes()
self._payload: bytes = bytes(payload) if payload is not None else bytes()

self._validate_tnf(tnf, self.record_type, self.record_id, self.payload)
self._validate_tnf(tnf, self._record_type, self._record_id, self._payload)

@property
def tnf(self) -> NdefTNF:
return self._tnf

@property
def record_type(self) -> bytes:
return self._record_type

@property
def record_id(self) -> bytes:
return self._record_id

@property
def payload(self) -> bytes:
return self._payload

def to_known_rtd(self) -> Optional[NdefRTD]:
try:
return NdefRTD(self.record_type)
return NdefRTD(self._record_type)
except ValueError:
return None

Expand Down Expand Up @@ -196,11 +212,7 @@ def _ensure_sane_payload_size(size: int) -> None:
cf = flag & NdefRecord._FLAG_CF != 0
sr = flag & NdefRecord._FLAG_SR != 0
il = flag & NdefRecord._FLAG_IL != 0
raw_tnf = flag & 0x07
try:
tnf = NdefTNF(raw_tnf)
except ValueError as e:
raise ValueError(f"unexpected tnf value: 0x{raw_tnf:02x}") from e
tnf = NdefTNF(flag & 0x07)

if not mb and len(records) == 0 and not in_chunk and not ignore_mb_me:
raise ValueError("expected MB flag")
Expand Down Expand Up @@ -257,22 +269,15 @@ def _ensure_sane_payload_size(size: int) -> None:
payload_length += len(chunk)
_ensure_sane_payload_size(payload_length)
payload = b''.join([chunk for chunk in chunks])
if chunk_tnf is not None:
tnf = chunk_tnf
else:
raise ValueError("unknown chunk tnf")
tnf = chunk_tnf

if cf:
in_chunk = True
continue
else:
in_chunk = False

if record_type is not None and record_id is not None:
NdefRecord._validate_tnf(tnf, record_type, record_id, payload)
else:
raise ValueError("unknown record type or record id")

NdefRecord._validate_tnf(tnf, record_type, record_id, payload)
records.append(NdefRecord(tnf, record_type, record_id, payload))

if ignore_mb_me:
Expand All @@ -287,41 +292,41 @@ def _ensure_sane_payload_size(size: int) -> None:

@property
def _flag_sr(self) -> bool:
return len(self.payload) < 256
return len(self._payload) < 256

@property
def _flag_il(self) -> bool:
return True if self.tnf == NdefTNF.EMPTY else len(self.record_id) > 0
return True if self._tnf == NdefTNF.EMPTY else len(self._record_id) > 0

def to_mime_type(self) -> Optional[str]:
if self.tnf == NdefTNF.WELL_KNOWN:
if self.record_type == NdefRTD.TEXT:
if self._tnf == NdefTNF.WELL_KNOWN:
if self._record_type == NdefRTD.TEXT:
return "text/plain"
elif self.tnf == NdefTNF.MIME_MEDIA:
raw_mime_type = self.record_type.decode("ascii")
elif self._tnf == NdefTNF.MIME_MEDIA:
raw_mime_type = self._record_type.decode("ascii")
return _normalize_mime_type(raw_mime_type)
return None

def to_uri(self) -> Optional[str]:
return self._to_uri(False)

def _to_uri(self, in_smart_poster: bool) -> Optional[str]:
if self.tnf == NdefTNF.WELL_KNOWN:
if self.record_type == NdefRTD.SMART_POSTER and not in_smart_poster:
for record in NdefMessage.parse(self.payload).records:
if self._tnf == NdefTNF.WELL_KNOWN:
if self._record_type == NdefRTD.SMART_POSTER and not in_smart_poster:
for record in NdefMessage.parse(self._payload):
uri = record._to_uri(True)
if uri is not None:
return _normalize_uri_scheme(uri)
elif self.record_type == NdefRTD.URI:
if len(self.payload) >= 2:
prefix_index = self.payload[0]
elif self._record_type == NdefRTD.URI:
if len(self._payload) >= 2:
prefix_index = self._payload[0]
if 0 <= prefix_index < len(_URI_PREFIX_MAP):
return _URI_PREFIX_MAP[prefix_index] + self.payload[1:].decode("utf-8")
elif self.tnf == NdefTNF.ABSOLUTE_URI:
return _normalize_uri_scheme(self.record_type.decode("utf-8"))
elif self.tnf == NdefTNF.EXTERNAL_TYPE:
return _URI_PREFIX_MAP[prefix_index] + self._payload[1:].decode("utf-8")
elif self._tnf == NdefTNF.ABSOLUTE_URI:
return _normalize_uri_scheme(self._record_type.decode("utf-8"))
elif self._tnf == NdefTNF.EXTERNAL_TYPE:
if not in_smart_poster:
return "vnd.android.nfc://ext/" + self.record_type.decode("ascii")
return "vnd.android.nfc://ext/" + self._record_type.decode("ascii")
return None

def to_bytes(self, flag_mb: bool = True, flag_me: bool = True) -> bytes:
Expand All @@ -330,21 +335,21 @@ def to_bytes(self, flag_mb: bool = True, flag_me: bool = True) -> bytes:
flag = (self._FLAG_MB if flag_mb else 0) | \
(self._FLAG_ME if flag_me else 0) | \
(self._FLAG_SR if self._flag_sr else 0) | \
(self._FLAG_IL if self._flag_il else 0) | self.tnf.value
(self._FLAG_IL if self._flag_il else 0) | self._tnf.value

buffer.append(flag)
buffer.append(len(self.record_type))
buffer.append(len(self._record_type))

if self._flag_sr:
buffer.append(len(self.payload))
buffer.append(len(self._payload))
else:
buffer.extend(len(self.payload).to_bytes(length=4, byteorder="big", signed=False))
buffer.extend(len(self._payload).to_bytes(length=4, byteorder="big", signed=False))
if self._flag_il:
buffer.append(len(self.record_id))
buffer.append(len(self._record_id))

buffer.extend(self.record_type)
buffer.extend(self.record_id)
buffer.extend(self.payload)
buffer.extend(self._record_type)
buffer.extend(self._record_id)
buffer.extend(self._payload)

return bytes(buffer)

Expand All @@ -359,68 +364,81 @@ def _validate_tnf(tnf: NdefTNF, record_type: bytes, record_id: bytes, payload: b
elif tnf == NdefTNF.UNCHANGED:
raise ValueError("unexpected TNF_UNCHANGED in first chunk or logical record")

def __len__(self) -> int:
length = 3 + len(self.record_type) + len(self.record_id) + len(self.payload)
def bytes_size(self) -> int:
length = 3 + len(self._record_type) + len(self._record_id) + len(self._payload)
if not self._flag_sr:
length += 3
if self._flag_il:
length += 1
return length

def __bytes__(self) -> bytes:
return self.to_bytes()

def __repr__(self) -> str:
return ("NdefRecord("
f"tnf=0x{self.tnf.value:02x}, "
f"type={self.record_type}, "
f"id={self.record_id}, "
f"payload={self.payload})")
f"tnf=0x{self._tnf.value:02x}, "
f"type={self._record_type}, "
f"id={self._record_id}, "
f"payload={self._payload})")

def __eq__(self, __value) -> bool:
if __value is None or not isinstance(__value, NdefRecord):
return super().__eq__(__value)
else:
return self.tnf == __value.tnf and \
self.record_type == __value.record_type and \
self.record_id == __value.record_id and \
self.payload == __value.payload
return self._tnf == __value._tnf and \
self._record_type == __value._record_type and \
self._record_id == __value._record_id and \
self._payload == __value._payload

def __hash__(self) -> int:
return hash((self.tnf.value, self.record_type, self.record_id, self.payload))
return hash((self._tnf.value, self._record_type, self._record_id, self._payload))


# /platform/frameworks/base/nfc/java/android/nfc/NdefMessage.java
class NdefMessage:
def __init__(self, *records: NdefRecord) -> None:
if len(records) == 0:
raise ValueError("must have at least one record")
self.records: tuple[NdefRecord, ...] = records
self._records: tuple[NdefRecord, ...] = records

@property
def records(self) -> tuple[NdefRecord, ...]:
return self._records

@staticmethod
def parse(buffer: bytes) -> 'NdefMessage':
records = NdefRecord.parse(buffer, False)
if len(records) == 0:
raise ValueError("must have at least one record")
else:
return NdefMessage(*records)
return NdefMessage(*records)

def to_bytes(self) -> bytes:
return b''.join(
[
record.to_bytes(i == 0, i == len(self.records) - 1)
for i, record in enumerate(self.records)
record.to_bytes(i == 0, i == len(self._records) - 1)
for i, record in enumerate(self._records)
]
)

def bytes_size(self) -> int:
return sum(i.bytes_size() for i in self._records)

def __bytes__(self) -> bytes:
return self.to_bytes()

def __iter__(self) -> Iterator[NdefRecord]:
return iter(self._records)

def __len__(self) -> int:
return sum([len(i) for i in self.records])
return len(self._records)

def __repr__(self) -> str:
return f"NdefMessage({', '.join([repr(i) for i in self.records])})"
return f"NdefMessage({', '.join([repr(i) for i in self._records])})"

def __eq__(self, __value) -> bool:
if __value is None or not isinstance(__value, NdefMessage):
return super().__eq__(__value)
else:
return self.records == __value.records
return self._records == __value._records

def __hash__(self) -> int:
return hash(self.records)
return hash(self._records)
22 changes: 16 additions & 6 deletions tests/test_ndef.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


# /cts/tests/tests/ndef/src/android/ndef/cts/NdefTest.java
# noinspection HttpUrlsUsage
class NdefTestCase(unittest.TestCase):
_PAYLOAD_255 = b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + \
b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + \
Expand Down Expand Up @@ -273,15 +274,24 @@ def test_to_bytes(self) -> None:
# 256 byte payload
self.assertEqual(b"\xc5\x00\x00\x00\x01\x00" + self._PAYLOAD_256, NdefMessage(NdefRecord(NdefTNF.UNKNOWN, None, None, self._PAYLOAD_256)).to_bytes())

self.assertEqual(
b"\xd8\x00\x00\x00",
bytes(NdefRecord(NdefTNF.EMPTY, None, None, None))
)
self.assertEqual(
b"\xd8\x00\x00\x00",
bytes(NdefMessage(NdefRecord(NdefTNF.EMPTY, None, None, None)))
)

def test_get_bytes_length(self) -> None:
# single short record
r = NdefRecord(NdefTNF.EMPTY, None, None, None)
b = b"\xd8\x00\x00\x00"
self.assertEqual(len(b), len(NdefMessage(r)))
self.assertEqual(len(b), NdefMessage(r).bytes_size())
# 3 records
r = NdefRecord(NdefTNF.EMPTY, None, None, None)
b = b"\x98\x00\x00\x00\x18\x00\x00\x00\x58\x00\x00\x00"
self.assertEqual(len(b), len(NdefMessage(r, r, r)))
self.assertEqual(len(b), NdefMessage(r, r, r).bytes_size())

def test_to_uri(self) -> None:
# absolute uri
Expand Down Expand Up @@ -354,12 +364,12 @@ def test_repr(self) -> None:
self.assertEqual(f"NdefMessage({p1})", repr(NdefMessage(r1)))
self.assertEqual(f"NdefMessage({p1}, {p2})", repr(NdefMessage(r1, r2)))

def test_len(self) -> None:
def test_iter_message(self) -> None:
r1 = NdefRecord(NdefTNF.EMPTY, None, None, None)
r2 = NdefRecord(NdefTNF.EXTERNAL_TYPE, b"type", b"\x01", self._PAYLOAD_256)
self.assertEqual(4, len(r1))
self.assertEqual(268, len(r2))
self.assertEqual(4 + 268, len(NdefMessage(r1, r2)))
self.assertEqual(1, len(list(NdefMessage(r1))))
self.assertEqual(2, len(list(NdefMessage(r1, r2))))
self.assertEqual(NdefMessage(r1, r2).records, tuple(iter(NdefMessage(r1, r2))))


if __name__ == "__main__":
Expand Down

0 comments on commit a9605a4

Please sign in to comment.