Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added an API for working with wheel files #805

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Previous commit
Lots of bug fixes, refactorings and coverage improvements
  • Loading branch information
agronholm committed Aug 2, 2024
commit 50a8631e4b3e71ef8fa7952387b0037ef2624f00
77 changes: 43 additions & 34 deletions src/packaging/wheelfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class WheelContentElement(NamedTuple):
size: int
stream: IO[bytes]

def __repr__(self) -> str:
return f"{self.__class__.__name__}({str(self.path)!r}, size={self.size!r})"


def _encode_hash_value(hash_value: bytes) -> str:
return urlsafe_b64encode(hash_value).rstrip(b"=").decode("ascii")
Expand Down Expand Up @@ -95,21 +98,28 @@ def __init__(

def read(self, amount: int = -1) -> bytes:
data = self._fp.read(amount)
if amount and self._record_entry is not None:
if data:
self._hash.update(data)
self._num_bytes_read += len(data)
elif self._record_entry:
# The file has been read in full – check that hash and file size match
# with the entry in RECORD
if self._hash.digest() != self._record_entry.hash_value:
raise WheelError(f"Hash mismatch for file {self._arcname!r}")
elif self._num_bytes_read != self._record_entry.filesize:
raise WheelError(
f"{self._arcname}: file size mismatch: "
f"{self._record_entry.filesize} bytes in RECORD, "
f"{self._num_bytes_read} bytes in archive"
)
if self._record_entry is None:
return data

if data:
self._hash.update(data)
self._num_bytes_read += len(data)

if amount < 0 or len(data) < amount:
# The file has been read in full – check that hash and file size match
# with the entry in RECORD
if self._num_bytes_read != self._record_entry.filesize:
raise WheelError(
f"{self._arcname}: file size mismatch: "
f"{self._record_entry.filesize} bytes in RECORD, "
f"{self._num_bytes_read} bytes in archive"
)
elif self._hash.digest() != self._record_entry.hash_value:
raise WheelError(
f"{self._arcname}: hash mismatch: "
f"{self._record_entry.hash_value.hex()} in RECORD, "
f"{self._hash.hexdigest()} in archive"
)

return data

Expand All @@ -125,7 +135,7 @@ def __exit__(
self._fp.close()

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._fp!r}, {self._arcname!r})"
return f"{self.__class__.__name__}({self._arcname!r})"


class WheelReader:
Expand Down Expand Up @@ -270,7 +280,7 @@ def read_dist_info(self, filename: str) -> str:

return contents.decode("utf-8")

def get_contents(self) -> Iterator[WheelContentElement]:
def iterate_contents(self) -> Iterator[WheelContentElement]:
for fname, entry in self._record_entries.items():
with self._zip.open(fname, "r") as stream:
yield WheelContentElement(
Expand All @@ -285,17 +295,10 @@ def validate_record(self) -> None:
if basename in _exclude_filenames:
continue

try:
record = self._record_entries[zinfo.filename]
except KeyError:
raise WheelError(f"No hash found for file {zinfo.filename!r}") from None

hash_ = hashlib.new(record.hash_algorithm)
with self._zip.open(zinfo) as fp:
hash_.update(fp.read(65536))

if hash_.digest() != record.hash_value:
raise WheelError(f"Hash mismatch for file {zinfo.filename!r}")
with self.open(zinfo.filename) as fp:
while True:
if not fp.read(65536):
break

def extractall(self, base_path: str | PathLike[str]) -> None:
basedir = Path(base_path)
Expand All @@ -307,27 +310,30 @@ def extractall(self, base_path: str | PathLike[str]) -> None:
for fname in self._zip.namelist():
target_path = basedir.joinpath(fname)
target_path.parent.mkdir(0o755, True, True)
with self._open_file(fname) as infile, target_path.open("wb") as outfile:
with self.open(fname) as infile, target_path.open("wb") as outfile:
while True:
data = infile.read(65536)
if not data:
break

outfile.write(data)

def _open_file(self, archive_name: str) -> WheelArchiveFile:
def open(self, archive_name: str) -> WheelArchiveFile:
basename = os.path.basename(archive_name)
if basename in _exclude_filenames:
record_entry = None
else:
record_entry = self._record_entries[archive_name]
try:
record_entry = self._record_entries[archive_name]
except KeyError:
raise WheelError(f"No hash found for file {archive_name!r}") from None

return WheelArchiveFile(
self._zip.open(archive_name), archive_name, record_entry
)

def read_file(self, archive_name: str) -> bytes:
with self._open_file(archive_name) as fp:
with self.open(archive_name) as fp:
return fp.read()

def read_data_file(self, filename: str) -> bytes:
Expand Down Expand Up @@ -446,7 +452,7 @@ def write_metadata(self, items: Iterable[tuple[str, str]]) -> None:
for key, value in items:
key = key.title()
if key == "Description":
msg.set_payload(value, "utf-8")
msg.set_payload(value.encode("utf-8"))
else:
msg.add_header(key, value)

Expand Down Expand Up @@ -541,4 +547,7 @@ def write_distinfo_file(
self.write_file(archive_path, contents, timestamp=timestamp)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.path_or_fd!r})"
return (
f"{self.__class__.__name__}({self.path_or_fd}, "
f"generator={self.generator!r})"
)
Loading
Loading