Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 8d034f7

Browse files
author
George Ohashi
committed
comments
1 parent 8dc4a67 commit 8d034f7

File tree

4 files changed

+48
-24
lines changed

4 files changed

+48
-24
lines changed

src/sparsezoo/api/graphql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def make_request(
8282
query = self.parse_query(
8383
operation_body=operation_body, arguments=arguments, fields=fields
8484
)
85-
85+
print(url or f"{BASE_API_URL}/v2/graphql")
8686
response = requests.post(
8787
url=url or f"{BASE_API_URL}/v2/graphql",
8888
json={

src/sparsezoo/api/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable, Dict, List
15+
from typing import Any, Callable, Dict, List
1616

1717

1818
def to_camel_case(string: str):
@@ -32,7 +32,7 @@ def to_snake_case(string: str):
3232

3333

3434
def map_keys(
35-
dictionary: Dict[str, str], mapper: Callable[[str], str]
35+
dictionary: Dict[str, Any], mapper: Callable[[str], str]
3636
) -> Dict[str, str]:
3737
"""
3838
Given a dictionary, update its keys to a given mapper callable.
@@ -44,8 +44,7 @@ def map_keys(
4444
if isinstance(value, List) or isinstance(value, Dict):
4545
value_type = type(value)
4646
mapped_dict[mapper(key)] = value_type(
47-
map_keys(dictionary=sub_dict, mapper=to_snake_case)
48-
for sub_dict in value
47+
map_keys(dictionary=sub_dict, mapper=mapper) for sub_dict in value
4948
)
5049
else:
5150
mapped_dict[mapper(key)] = value

src/sparsezoo/model/utils.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,12 @@
7575
r"/(?P<architecture>[\.A-z0-9_]+)(-(?P<sub_architecture>[\.A-z0-9_]+))?"
7676
r"/(?P<framework>[\.A-z0-9_]+)"
7777
r"/(?P<repo>[\.A-z0-9_]+)"
78-
r"/(?P<dataset>[\.A-z0-9_]+)"
78+
r"/(?P<dataset>[\.A-z0-9_]+)(-(?P<training_scheme>[\.A-z0-9_]+))?"
7979
r"/(?P<sparse_tag>[\.A-z0-9_-]+)"
8080
)
8181

8282
STUB_V2_REGEX_EXPR = (
83+
r"^(zoo:)?"
8384
r"(?P<architecture>[\.A-z0-9_]+)"
8485
r"(-(?P<sub_architecture>[\.A-z0-9_]+))?"
8586
r"-(?P<source_dataset>[\.A-z0-9_]+)"
@@ -148,22 +149,35 @@ def load_files_from_stub(
148149
],
149150
)
150151

151-
if len(models):
152+
matching_models = len(models)
153+
if matching_models == 0:
154+
raise ValueError(
155+
f"No matching models found with stub: {stub}." "Please try another stub"
156+
)
157+
if matching_models > 1:
158+
logging.warning(
159+
f"{len(models)} found from the stub: {stub}"
160+
"Using the first model to obtain metadata."
161+
"Proceed with caution"
162+
)
163+
164+
if matching_models:
165+
model = models[0]
152166

153-
model_id = models[0]["model_id"]
167+
model_id = model["model_id"]
154168

155-
files = models[0].get("files")
169+
files = model.get("files")
156170
include_file_download_url(files)
157171
files = restructure_request_json(request_json=files)
158172

159173
if params is not None:
160174
files = filter_files(files=files, params=params)
161175

162-
training_results = models[0].get("training_results")
176+
training_results = model.get("training_results")
163177

164-
benchmark_results = models[0].get("benchmark_results")
178+
benchmark_results = model.get("benchmark_results")
165179

166-
model_onnx_size_compressed_bytes = models[0]["model_onnx_size_compressed_bytes"]
180+
model_onnx_size_compressed_bytes = model["model_onnx_size_compressed_bytes"]
167181

168182
throughput_results = _parse_results_metrics(
169183
results=benchmark_results, parser=ThroughputResults
@@ -601,9 +615,8 @@ def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
601615
return {}
602616

603617

604-
def is_stub(candidate: str):
605-
if candidate.startswith(ZOO_STUB_PREFIX):
606-
return True
607-
if bool(get_model_metadata_from_stub(candidate)):
608-
return True
609-
return False
618+
def is_stub(candidate: str) -> bool:
619+
return bool(
620+
re.match(STUB_V1_REGEX_EXPR, candidate)
621+
or re.match(STUB_V2_REGEX_EXPR, candidate)
622+
)

tests/sparsezoo/model/test_model.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,32 +137,44 @@ def _assert_validation_results_exist(model):
137137
"stub, clone_sample_outputs, expected_files",
138138
[
139139
(
140-
"zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate", # noqa E501
140+
(
141+
"zoo:"
142+
"cv/classification/mobilenet_v1-1.0/"
143+
"pytorch/sparseml/imagenet/pruned-moderate"
144+
),
141145
True,
142146
files_ic,
143147
),
144148
(
145-
"zoo:nlp/question_answering/distilbert-none/pytorch/huggingface/squad/pruned80_quant-none-vnni", # noqa E501
149+
(
150+
"zoo:"
151+
"nlp/question_answering/distilbert-none/"
152+
"pytorch/huggingface/squad/pruned80_quant-none-vnni"
153+
),
146154
False,
147155
files_nlp,
148156
),
149157
(
150-
"zoo:cv/detection/yolov5-s/pytorch/ultralytics/coco/pruned_quant-aggressive_94", # noqa E501
158+
(
159+
"zoo:"
160+
"cv/detection/yolov5-s/"
161+
"pytorch/ultralytics/coco/pruned_quant-aggressive_94"
162+
),
151163
True,
152164
files_yolo,
153165
),
154166
(
155-
"mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block", # noqa E501
167+
"mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block",
156168
False,
157169
files_yolo,
158170
),
159171
(
160-
"yolov5-s-coco-pruned85_quantized", # noqa E501
172+
"yolov5-s-coco-pruned85_quantized",
161173
False,
162174
files_yolo,
163175
),
164176
(
165-
"resnet_v1-50-imagenet-channel30_pruned91", # noqa E501
177+
"resnet_v1-50-imagenet-channel30_pruned91",
166178
False,
167179
files_yolo,
168180
),

0 commit comments

Comments
 (0)