Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 4486de9

Browse files
dbogunowiczmgoin
andauthored
Avoid appending to external data when running onnx_save (#320)
* initial commit * fix docstrings * fix blunders in logic --------- Co-authored-by: Michael Goin <michael@neuralmagic.com>
1 parent 960b213 commit 4486de9

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

src/sparsezoo/utils/onnx.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import os
1617
from typing import Any, Dict, Optional, Tuple, Union
1718

1819
import numpy
@@ -48,8 +49,11 @@
4849
"group_four_block",
4950
"extract_node_id",
5051
"get_node_attributes",
52+
"EXTERNAL_ONNX_DATA_NAME",
5153
]
5254

55+
EXTERNAL_ONNX_DATA_NAME = "model.data"
56+
5357

5458
def onnx_includes_external_data(model: ModelProto) -> bool:
5559
"""
@@ -75,7 +79,10 @@ def onnx_includes_external_data(model: ModelProto) -> bool:
7579

7680

7781
def save_onnx(
78-
model: ModelProto, model_path: str, external_data_file: Optional[str] = None
82+
model: ModelProto,
83+
model_path: str,
84+
large_model_external_data_file: str = EXTERNAL_ONNX_DATA_NAME,
85+
external_data_file: Optional[str] = None,
7986
) -> bool:
8087
"""
8188
Save model to the given path.
@@ -88,11 +95,21 @@ def save_onnx(
8895
8996
:param model: The model to save.
9097
:param model_path: The path to save the model to.
98+
:param large_model_external_data_file: The default name to save the external
99+
data to if the model is too large to be saved as a single protobuf.
100+
If:
101+
- the model is too large to be saved as a single protobuf, AND
102+
- `external_data_file` is specified,
103+
then the external data of the model will be saved to `external_data_file`
104+
instead of `large_model_external_data_name`.
91105
:param external_data_file: The optional name save the external data to.
92106
:return True if the model was saved with external data, False otherwise.
93107
"""
94108
if external_data_file is not None:
95109
_LOGGER.debug(f"Saving with external data: {external_data_file}")
110+
_check_for_old_external_data(
111+
model_path=model_path, external_data_file=external_data_file
112+
)
96113
onnx.save(
97114
model,
98115
model_path,
@@ -104,16 +121,18 @@ def save_onnx(
104121

105122
if model.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
106123
_LOGGER.warning(
107-
"The ONNX model is too large to be saved as a single protobuf."
108-
"Saving with external data: 'model.data'"
124+
"The ONNX model is too large to be saved as a single protobuf. "
125+
f"Saving with external data: {large_model_external_data_file}"
126+
)
127+
_check_for_old_external_data(
128+
model_path=model_path, external_data_file=large_model_external_data_file
109129
)
110-
111130
onnx.save(
112131
model,
113132
model_path,
114133
save_as_external_data=True,
115134
all_tensors_to_one_file=True,
116-
location="model.data",
135+
location=large_model_external_data_file,
117136
)
118137
return True
119138

@@ -566,3 +585,19 @@ def _get_node_input(
566585
return node.input[index]
567586
else:
568587
return default
588+
589+
590+
def _check_for_old_external_data(model_path: str, external_data_file: str):
591+
old_external_data_file = os.path.join(
592+
os.path.dirname(model_path), external_data_file
593+
)
594+
if os.path.exists(old_external_data_file):
595+
_LOGGER.warning(
596+
f"Attempting to save external data for a model: {model_path} "
597+
f"to a directory:{os.path.dirname(model_path)} "
598+
f"that already contains external data file: {external_data_file}. "
599+
"The external data file will be overwritten."
600+
)
601+
os.remove(old_external_data_file)
602+
603+
return

0 commit comments

Comments
 (0)