Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
save_outputs_to_tar,
)
from sparsezoo.objects import (
AliasedSelectDirectory,
Directory,
File,
NumpyDirectory,
Expand Down Expand Up @@ -138,12 +137,26 @@ def __init__(self, source: str, download_path: Optional[str] = None):
files, directory_class=Directory, display_name="sample-labels"
)

self.deployment: AliasedSelectDirectory = self._directory_from_files(
self.deployment: SelectDirectory = self._directory_from_files(
files,
directory_class=AliasedSelectDirectory,
directory_class=SelectDirectory,
display_name="deployment",
download_alias="deployment.tar.gz",
stub_params=self.stub_params,
allow_multiple_outputs=True,
)

if isinstance(self.deployment, list):
# if there are multiple deployment directories
# (this may happen due to the presence of both
# - deployment directory
# - deployment.tar.gz file
# we need to choose one (they are identical)
self.deployment = self.deployment[0]

self.deployment_tar: SelectDirectory = self._directory_from_files(
files,
directory_class=SelectDirectory,
display_name="deployment.tar.gz",
)

self.onnx_folder: Directory = self._directory_from_files(
Expand Down Expand Up @@ -196,6 +209,7 @@ def __init__(self, source: str, download_path: Optional[str] = None):
self._files_dictionary = {
"training": self.training,
"deployment": self.deployment,
"deployment.tar.gz": self.deployment_tar,
"onnx_folder": self.onnx_folder,
"logs": self.logs,
"sample_originals": self.sample_originals,
Expand Down Expand Up @@ -233,9 +247,9 @@ def deployment_directory_path(self) -> str:
deployment directory if compressed
"""
# trigger initial download if not downloaded
self.deployment.path
if self.deployment.is_archive:
self.deployment.unzip()
self.deployment_tar.path
if self.deployment_tar.is_archive:
self.deployment_tar.unzip()

return self.deployment.path

Expand Down Expand Up @@ -310,6 +324,12 @@ def download(
else:
downloads = []
for key, file in self._files_dictionary.items():
if key == "deployment":
# skip the download of the deployment directory
# since identical files will be downloaded
# in the deployment_tar
_LOGGER.debug(f"Intentionally skipping downloading the file {key}")
continue
if file is not None:
# save all the files to a temporary directory
downloads.append(self._download(file, download_path))
Expand Down
15 changes: 14 additions & 1 deletion src/sparsezoo/objects/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,22 @@ def unzip(self, extract_directory: Optional[str] = None, force: bool = False):
member.name = os.path.basename(member.name)
tar.extract(member=member, path=path)
files.append(
File(name=member.name, path=os.path.join(path, member.name))
File(
name=member.name,
path=os.path.join(path, member.name),
parent_directory=path,
)
)
tar.close()
# if path already exists, then the tar archive has already been unzipped
# and we can just use the files in the directory
elif os.path.exists(path):
for file in os.listdir(path):
files.append(
File(
name=file, path=os.path.join(path, file), parent_directory=path
)
)

self.name = name
self.files = files
Expand Down
56 changes: 56 additions & 0 deletions tests/sparsezoo/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import pytest

from sparsezoo import Model
from sparsezoo.objects.directories import SelectDirectory


files_ic = {
"training",
"deployment.tar.gz",
"deployment",
"logs",
"onnx",
Expand Down Expand Up @@ -182,6 +184,10 @@ def setup(self, stub, clone_sample_outputs, expected_files):
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
model = Model(stub, temp_dir.name)
model.download()
# since downloading the `deployment` file is
# disabled by default, we need to do it
# explicitly
model.deployment.download()
self._add_mock_files(temp_dir.name, clone_sample_outputs=clone_sample_outputs)
model = Model(temp_dir.name)

Expand Down Expand Up @@ -329,6 +335,56 @@ def test_model_gz_extraction_from_local_files(stub: str):
shutil.rmtree(temp_dir.name)


@pytest.mark.parametrize(
"stub",
[
"zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/"
"imagenet/pruned-moderate",
],
)
def test_model_deployment_directory(stub):
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
expected_deployment_files = ["model.onnx"]

model = Model(stub, temp_dir.name)
assert model.deployment_tar.is_archive
# download and extract deployment tar
deployment_dir_path = model.deployment_directory_path

# deployment and deployment_tar should be point to the same files
assert deployment_dir_path == model.deployment_tar.path == model.deployment.path
# make sure that the model contains expected files
assert set(os.listdir(temp_dir.name)) == {"deployment.tar.gz", "deployment"}
assert (
os.listdir(os.path.join(temp_dir.name, "deployment"))
== expected_deployment_files
)

assert isinstance(model.deployment, SelectDirectory)
# TODO: this should be 1. However, the API is returning for `deployment` file type
# both `model.onnx` and `deployment/model.onnx`.
# This should probably be fixed on the API side
assert (
len(model.deployment.files) == 2
) # should be == len(expected_deployment_files)

assert isinstance(model.deployment_tar, SelectDirectory)
assert len(model.deployment_tar.files) == len(expected_deployment_files)
assert not model.deployment_tar.is_archive

# test recreating the model from the local files
model = Model(temp_dir.name)

assert isinstance(model.deployment, SelectDirectory)
assert len(model.deployment.files) == len(expected_deployment_files)

assert isinstance(model.deployment_tar, SelectDirectory)
assert len(model.deployment_tar.files) == len(expected_deployment_files)
assert not model.deployment_tar.is_archive

shutil.rmtree(temp_dir.name)


def _extraction_test_helper(model: Model):
# download and extract model.onnx.tar.gz
# path should point to extracted model.onnx file
Expand Down