1313# limitations under the License.
1414
1515import logging
16+ import os
1617from typing import Any , Dict , Optional , Tuple , Union
1718
1819import numpy
4849 "group_four_block" ,
4950 "extract_node_id" ,
5051 "get_node_attributes" ,
52+ "EXTERNAL_ONNX_DATA_NAME" ,
5153]
5254
55+ EXTERNAL_ONNX_DATA_NAME = "model.data"
56+
5357
5458def onnx_includes_external_data (model : ModelProto ) -> bool :
5559 """
@@ -75,7 +79,10 @@ def onnx_includes_external_data(model: ModelProto) -> bool:
7579
7680
7781def save_onnx (
78- model : ModelProto , model_path : str , external_data_file : Optional [str ] = None
82+ model : ModelProto ,
83+ model_path : str ,
84+ large_model_external_data_file : str = EXTERNAL_ONNX_DATA_NAME ,
85+ external_data_file : Optional [str ] = None ,
7986) -> bool :
8087 """
8188 Save model to the given path.
@@ -88,11 +95,21 @@ def save_onnx(
8895
8996 :param model: The model to save.
9097 :param model_path: The path to save the model to.
98+ :param large_model_external_data_file: The default name to save the external
99+ data to if the model is too large to be saved as a single protobuf.
100+ If:
101+ - the model is too large to be saved as a single protobuf, AND
102+ - `external_data_file` is specified,
103+ then the external data of the model will be saved to `external_data_file`
104+ instead of `large_model_external_data_name`.
91105 :param external_data_file: The optional name save the external data to.
92106 :return True if the model was saved with external data, False otherwise.
93107 """
94108 if external_data_file is not None :
95109 _LOGGER .debug (f"Saving with external data: { external_data_file } " )
110+ _check_for_old_external_data (
111+ model_path = model_path , external_data_file = external_data_file
112+ )
96113 onnx .save (
97114 model ,
98115 model_path ,
@@ -104,16 +121,18 @@ def save_onnx(
104121
105122 if model .ByteSize () > onnx .checker .MAXIMUM_PROTOBUF :
106123 _LOGGER .warning (
107- "The ONNX model is too large to be saved as a single protobuf."
108- "Saving with external data: 'model.data'"
124+ "The ONNX model is too large to be saved as a single protobuf. "
125+ f"Saving with external data: { large_model_external_data_file } "
126+ )
127+ _check_for_old_external_data (
128+ model_path = model_path , external_data_file = large_model_external_data_file
109129 )
110-
111130 onnx .save (
112131 model ,
113132 model_path ,
114133 save_as_external_data = True ,
115134 all_tensors_to_one_file = True ,
116- location = "model.data" ,
135+ location = large_model_external_data_file ,
117136 )
118137 return True
119138
@@ -566,3 +585,19 @@ def _get_node_input(
566585 return node .input [index ]
567586 else :
568587 return default
588+
589+
590+ def _check_for_old_external_data (model_path : str , external_data_file : str ):
591+ old_external_data_file = os .path .join (
592+ os .path .dirname (model_path ), external_data_file
593+ )
594+ if os .path .exists (old_external_data_file ):
595+ _LOGGER .warning (
596+ f"Attempting to save external data for a model: { model_path } "
597+ f"to a directory:{ os .path .dirname (model_path )} "
598+ f"that already contains external data file: { external_data_file } . "
599+ "The external data file will be overwritten."
600+ )
601+ os .remove (old_external_data_file )
602+
603+ return
0 commit comments