17
17
import pathlib
18
18
from typing import Callable
19
19
20
- import onnx
21
-
22
- from onnxscript import ir
20
+ from onnxscript import ir , optimizer
23
21
from onnxscript .function_libs .torch_lib import registration
24
22
from onnxscript .ir import _external_data
25
23
26
- # Internal flag. Will go away.
27
- _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
28
- os .getenv ("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR" ) == "1"
29
- )
30
-
31
24
32
25
@dataclasses .dataclass (frozen = True )
33
26
class _OnnxFunctionMeta :
@@ -49,8 +42,10 @@ class _OnnxFunctionMeta:
49
42
50
43
def optimize (model : ir .Model ) -> ir .Model :
51
44
"""Optimize the model."""
52
-
53
- # TODO(justinchuby): Use the optimizer
45
+ # Internal flag. Will go away.
46
+ enabled = os .getenv ("TORCH_ONNX_ENABLE_OPTIMIZATION" ) == "1"
47
+ if enabled :
48
+ optimizer .optimize_ir (model )
54
49
return model
55
50
56
51
@@ -81,45 +76,32 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
81
76
"""Save the model with external data. The model is unchanged after saving."""
82
77
83
78
# TODO(#1835): Decide if we want to externalize large attributes as well
84
- if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR :
85
- initializer_values = tuple (model .graph .initializers .values ())
86
- tensors = [v .const_value for v in initializer_values ]
87
- for tensor in tensors :
88
- if tensor is None :
89
- raise ValueError (
90
- "The model contains uninitialized initializer values. "
91
- "Please make sure all initializer values are initialized."
92
- )
93
- destination_path = pathlib .Path (model_path )
94
- base_dir = destination_path .parent
95
- data_path = f"{ destination_path .name } .data"
96
-
97
- external_tensors = _external_data .convert_tensors_to_external (
98
- tensors , # type: ignore[arg-type]
99
- base_dir ,
100
- data_path ,
101
- )
102
-
103
- # Replace the initializer values with external tensors and save the model
104
- for initializer , external_tensor in zip (initializer_values , external_tensors ):
105
- initializer .const_value = external_tensor
106
- ir .save (model , model_path )
107
-
108
- # Restore the original initializer values so the model is unchanged
109
- for initializer , tensor in zip (initializer_values , tensors ):
110
- initializer .const_value = tensor
111
-
112
- else :
113
- destination_path = pathlib .Path (model_path )
114
- # Create the directory if it does not exist
115
- data_path = f"{ destination_path .name } .data"
116
- proto = ir .serde .serialize_model (model )
117
- onnx .save_model (
118
- proto ,
119
- model_path ,
120
- save_as_external_data = True ,
121
- location = data_path ,
122
- )
79
+ initializer_values = tuple (model .graph .initializers .values ())
80
+ tensors = [v .const_value for v in initializer_values ]
81
+ for tensor in tensors :
82
+ if tensor is None :
83
+ raise ValueError (
84
+ "The model contains uninitialized initializer values. "
85
+ "Please make sure all initializer values are initialized."
86
+ )
87
+ destination_path = pathlib .Path (model_path )
88
+ base_dir = destination_path .parent
89
+ data_path = f"{ destination_path .name } .data"
90
+
91
+ external_tensors = _external_data .convert_tensors_to_external (
92
+ tensors , # type: ignore[arg-type]
93
+ base_dir ,
94
+ data_path ,
95
+ )
96
+
97
+ # Replace the initializer values with external tensors and save the model
98
+ for initializer , external_tensor in zip (initializer_values , external_tensors ):
99
+ initializer .const_value = external_tensor
100
+ ir .save (model , model_path )
101
+
102
+ # Restore the original initializer values so the model is unchanged
103
+ for initializer , tensor in zip (initializer_values , tensors ):
104
+ initializer .const_value = tensor
123
105
124
106
125
107
def get_torchlib_ops () -> list [_OnnxFunctionMeta ]:
0 commit comments