Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit cb3c1fe

Browse files
Clean up relationship between deployment_tar and deployment (#389)
* [WIP] clean up relationship between deployment_tar and deployment * note that deployment_directory_path can be removed * potentially working solution * chaotic but working commit * tests pass * fix typo * address PR comments --------- Co-authored-by: Damian <damian@neuralmagic.com> Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
1 parent a0e0946 commit cb3c1fe

File tree

6 files changed

+66
-100
lines changed

6 files changed

+66
-100
lines changed

src/sparsezoo/model/model.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,14 @@ def __init__(self, source: str, download_path: Optional[str] = None):
111111
self.sample_originals: Directory = self._directory_from_files(
112112
files,
113113
directory_class=Directory,
114+
allow_multiple_outputs=True,
114115
display_name="sample-originals",
115116
)
116117
self.sample_inputs: NumpyDirectory = self._directory_from_files(
117118
files,
118119
directory_class=NumpyDirectory,
119120
display_name="sample-inputs",
121+
allow_multiple_outputs=True,
120122
)
121123

122124
self.model_card: File = self._file_from_files(files, display_name="model.md")
@@ -134,30 +136,25 @@ def __init__(self, source: str, download_path: Optional[str] = None):
134136
] = self._sample_outputs_list_to_dict(self.sample_outputs)
135137

136138
self.sample_labels: Directory = self._directory_from_files(
137-
files, directory_class=Directory, display_name="sample-labels"
138-
)
139-
140-
self.deployment: SelectDirectory = self._directory_from_files(
141139
files,
142-
directory_class=SelectDirectory,
143-
display_name="deployment",
144-
stub_params=self.stub_params,
140+
directory_class=Directory,
145141
allow_multiple_outputs=True,
142+
display_name="sample-labels",
146143
)
147144

148-
if isinstance(self.deployment, list):
149-
# if there are multiple deployment directories
150-
# (this may happen due to the presence of both
151-
# - deployment directory
152-
# - deployment.tar.gz file
153-
# we need to choose one (they are identical)
154-
self.deployment = self.deployment[0]
155-
156145
self.deployment_tar: SelectDirectory = self._directory_from_files(
157146
files,
158147
directory_class=SelectDirectory,
159148
display_name="deployment.tar.gz",
160149
)
150+
self.deployment: SelectDirectory = self._directory_from_files(
151+
files,
152+
directory_class=SelectDirectory,
153+
display_name="deployment",
154+
stub_params=self.stub_params,
155+
allow_multiple_outputs=True,
156+
tar_directory=self.deployment_tar,
157+
)
161158

162159
self.onnx_folder: Directory = self._directory_from_files(
163160
files,
@@ -194,6 +191,30 @@ def __init__(self, source: str, download_path: Optional[str] = None):
194191
# compressed file size on disk in bytes
195192
self.compressed_size: Optional[int] = compressed_size
196193

194+
# if there are multiple deployment directories
195+
# (this may happen due to the presence of both e.g.:
196+
# - deployment directory
197+
# - deployment.tar.gz file
198+
# we need to choose one (they are identical at this point)
199+
self.sample_originals = (
200+
self.sample_originals[0]
201+
if isinstance(self.sample_originals, list)
202+
else self.sample_originals
203+
)
204+
self.sample_inputs = (
205+
self.sample_inputs[0]
206+
if isinstance(self.sample_inputs, list)
207+
else self.sample_inputs
208+
)
209+
self.sample_labels = (
210+
self.sample_labels[0]
211+
if isinstance(self.sample_labels, list)
212+
else self.sample_labels
213+
)
214+
self.deployment = (
215+
self.deployment[0] if isinstance(self.deployment, list) else self.deployment
216+
)
217+
197218
# sorting name of `sample_inputs` and `sample_output` files,
198219
# so that they have same one-to-one correspondence when we jointly
199220
# iterate over them
@@ -209,7 +230,6 @@ def __init__(self, source: str, download_path: Optional[str] = None):
209230
self._files_dictionary = {
210231
"training": self.training,
211232
"deployment": self.deployment,
212-
"deployment.tar.gz": self.deployment_tar,
213233
"onnx_folder": self.onnx_folder,
214234
"logs": self.logs,
215235
"sample_originals": self.sample_originals,
@@ -239,20 +259,6 @@ def __init__(self, source: str, download_path: Optional[str] = None):
239259

240260
self.integration_validator = IntegrationValidator(model=self)
241261

242-
@property
243-
def deployment_directory_path(self) -> str:
244-
"""
245-
:return: file path of uncompressed deployemnt directory. Both (1) downloads
246-
compressed deployemnent directory if not downloaded (2) uncompresses
247-
deployment directory if compressed
248-
"""
249-
# trigger initial download if not downloaded
250-
self.deployment_tar.path
251-
if self.deployment_tar.is_archive:
252-
self.deployment_tar.unzip()
253-
254-
return self.deployment.path
255-
256262
@property
257263
def stub_params(self) -> Dict[str, str]:
258264
"""
@@ -324,12 +330,6 @@ def download(
324330
else:
325331
downloads = []
326332
for key, file in self._files_dictionary.items():
327-
if key == "deployment":
328-
# skip the download of the deployment directory
329-
# since identical files will be downloaded
330-
# in the deployment_tar
331-
_LOGGER.debug(f"Intentionally skipping downloading the file {key}")
332-
continue
333333
if file is not None:
334334
# save all the files to a temporary directory
335335
downloads.append(self._download(file, download_path))
@@ -636,8 +636,8 @@ def _directory_from_files(
636636
files: List[Dict[str, Any]],
637637
directory_class: Union[Directory, NumpyDirectory] = Directory,
638638
display_name: Optional[str] = None,
639-
regex: Optional[bool] = False,
640-
allow_multiple_outputs: Optional[bool] = False,
639+
regex: bool = False,
640+
allow_multiple_outputs: bool = False,
641641
**kwargs: object,
642642
) -> Union[List[Union[Directory, Any, None]], List[Directory], None]:
643643

@@ -746,10 +746,11 @@ def _sample_outputs_list_to_dict(
746746
engine_name = directory.name.split("_")[-1]
747747
if engine_name.endswith(".tar.gz"):
748748
engine_name = engine_name.replace(".tar.gz", "")
749-
if engine_name not in ENGINES:
749+
if engine_name not in ENGINES and engine_name != "sample-outputs":
750750
raise ValueError(
751-
f"The name of the 'sample-outputs' directory should "
752-
f"end with an engine name (one of the {ENGINES}). "
751+
f"The name of the sample-outputs directory should be"
752+
f"`sample-outputs` or shoud start with `sample-outputs_` and "
753+
f"end with an engine name (one of the {ENGINES})."
753754
f"However, the name is {directory.name}."
754755
)
755756
engine_to_numpydir_map[engine_name] = directory

src/sparsezoo/objects/directories.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ class SelectDirectory(Directory):
197197
:param parent_directory: path of the parent SelectDirectory
198198
:param stub_params: dictionary of zoo stub params that this directory
199199
was specified with
200+
:param tar_directory: optional pointer to the tar_directory
201+
of this directory. By default, when downloading the directory
202+
in question, we should download and extract the tarball.
200203
"""
201204

202205
def __init__(
@@ -207,6 +210,7 @@ def __init__(
207210
url: Optional[str] = None,
208211
parent_directory: Optional[str] = None,
209212
stub_params: Optional[Dict[str, str]] = None,
213+
tar_directory: Optional[Directory] = None,
210214
):
211215
self._default, self._available = None, None
212216

@@ -217,7 +221,7 @@ def __init__(
217221
url=url,
218222
parent_directory=parent_directory,
219223
)
220-
224+
self.tar_directory = tar_directory
221225
self._stub_params = stub_params or {}
222226
self.files_dict = self.files_to_dictionary()
223227

src/sparsezoo/objects/directory.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,18 +168,26 @@ def download(
168168
"Please make sure that `destination_path` argument is not None."
169169
)
170170

171+
# If tar_directory is not None, then we are downloading
172+
# the directory as a tar archive file
173+
target_directory = (
174+
self if getattr(self, "tar_directory", None) is None else self.tar_directory
175+
)
176+
171177
# Directory can represent a tar file.
172-
if self.is_archive:
173-
new_file_path = os.path.join(destination_path, self.name)
178+
# In this case, we download the tar file and unzip it.
179+
if target_directory.is_archive:
180+
new_file_path = os.path.join(destination_path, target_directory.name)
174181
for attempt in range(retries):
175182
try:
176183
download_file(
177-
url_path=self.url,
184+
url_path=target_directory.url,
178185
dest_path=new_file_path,
179186
overwrite=overwrite,
180187
)
181188

182-
self._path = new_file_path
189+
target_directory._path = new_file_path
190+
target_directory.unzip()
183191
return
184192

185193
except Exception as err:
@@ -192,13 +200,15 @@ def download(
192200

193201
# Directory can represent a folder or directory.
194202
else:
195-
for file in self.files:
203+
for file in target_directory.files:
196204
file.download(
197205
destination_path=destination_path,
198206
)
199-
file._path = os.path.join(destination_path, self.name, file.name)
207+
file._path = os.path.join(
208+
destination_path, target_directory.name, file.name
209+
)
200210

201-
self._path = os.path.join(destination_path, self.name)
211+
target_directory._path = os.path.join(destination_path, target_directory.name)
202212

203213
def get_file(self, file_name: str) -> Optional[File]:
204214
"""

src/sparsezoo/objects/file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def path(self):
9595
elif not os.path.exists(self._path):
9696
self.download()
9797

98-
return self._path
98+
return self._path or self.path
9999

100100
@classmethod
101101
def from_dict(

tests/sparsezoo/model/test_model.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import pytest
2424

2525
from sparsezoo import Model
26-
from sparsezoo.objects.directories import SelectDirectory
2726

2827

2928
files_ic = {
@@ -184,10 +183,6 @@ def setup(self, stub, clone_sample_outputs, expected_files):
184183
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
185184
model = Model(stub, temp_dir.name)
186185
model.download()
187-
# since downloading the `deployment` file is
188-
# disabled by default, we need to do it
189-
# explicitly
190-
model.deployment.download()
191186
self._add_mock_files(temp_dir.name, clone_sample_outputs=clone_sample_outputs)
192187
model = Model(temp_dir.name)
193188

@@ -342,49 +337,6 @@ def test_model_gz_extraction_from_local_files(stub: str):
342337
"imagenet/pruned-moderate",
343338
],
344339
)
345-
def test_model_deployment_directory(stub):
346-
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
347-
expected_deployment_files = ["model.onnx"]
348-
349-
model = Model(stub, temp_dir.name)
350-
assert model.deployment_tar.is_archive
351-
# download and extract deployment tar
352-
deployment_dir_path = model.deployment_directory_path
353-
354-
# deployment and deployment_tar should be point to the same files
355-
assert deployment_dir_path == model.deployment_tar.path == model.deployment.path
356-
# make sure that the model contains expected files
357-
assert set(os.listdir(temp_dir.name)) == {"deployment.tar.gz", "deployment"}
358-
assert (
359-
os.listdir(os.path.join(temp_dir.name, "deployment"))
360-
== expected_deployment_files
361-
)
362-
363-
assert isinstance(model.deployment, SelectDirectory)
364-
# TODO: this should be 1. However, the API is returning for `deployment` file type
365-
# both `model.onnx` and `deployment/model.onnx`.
366-
# This should probably be fixed on the API side
367-
assert (
368-
len(model.deployment.files) == 2
369-
) # should be == len(expected_deployment_files)
370-
371-
assert isinstance(model.deployment_tar, SelectDirectory)
372-
assert len(model.deployment_tar.files) == len(expected_deployment_files)
373-
assert not model.deployment_tar.is_archive
374-
375-
# test recreating the model from the local files
376-
model = Model(temp_dir.name)
377-
378-
assert isinstance(model.deployment, SelectDirectory)
379-
assert len(model.deployment.files) == len(expected_deployment_files)
380-
381-
assert isinstance(model.deployment_tar, SelectDirectory)
382-
assert len(model.deployment_tar.files) == len(expected_deployment_files)
383-
assert not model.deployment_tar.is_archive
384-
385-
shutil.rmtree(temp_dir.name)
386-
387-
388340
def _extraction_test_helper(model: Model):
389341
# download and extract model.onnx.tar.gz
390342
# path should point to extracted model.onnx file

tests/sparsezoo/model/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,14 @@ def test_setup_model_from_paths(self, setup):
172172
"recipe.md",
173173
"model.onnx",
174174
"model.onnx.tar.gz",
175-
"sample-inputs.tar.gz",
175+
"sample-inputs",
176176
}
177177
check_extraneous_files(expected_files, temp_dir, ignore_external_data)
178178

179179
def test_setup_model_from_objects(self, setup):
180180
stub, temp_dir, download_dir, ignore_external_data = setup
181181
model = Model(stub, download_dir.name)
182182
model.download()
183-
model.sample_inputs.unzip()
184183

185184
training = model.training
186185
deployment = model.deployment

0 commit comments

Comments
 (0)