@@ -1042,14 +1042,16 @@ def save_model(self, output_dir: Optional[str]=None):
1042
1042
def export_model (self ,
1043
1043
input_spec = None ,
1044
1044
load_best_model = False ,
1045
- output_dir : Optional [str ]= None ):
1046
- """ Export paddle inference model.
1045
+ output_dir : Optional [str ]= None ,
1046
+ model_format : Optional [str ]= "paddle" ):
1047
+ """Export a Paddle inference model or an ONNX model.
1047
1048
1048
1049
Args:
1049
1050
input_spec (paddle.static.InputSpec, optional): InputSpec describes the signature information of the model input,
1050
1051
such as shape, dtype, and name. Defaults to None.
1051
1052
load_best_model (bool, optional): Load best model. Defaults to False.
1052
1053
output_dir (Optional[str], optional): Output dir to save the exported model. Defaults to None.
1054
+ model_format (Optional[str], optional): Format to export to; one of "paddle" or "onnx" (case-insensitive). Defaults to "paddle".
1053
1055
"""
1054
1056
1055
1057
if output_dir is None :
@@ -1079,14 +1081,26 @@ def export_model(self,
1079
1081
model = unwrap_model (self .model )
1080
1082
model .eval ()
1081
1083
1082
- # Convert to static graph with specific input description
1083
- model = paddle .jit .to_static (model , input_spec = input_spec )
1084
-
1085
- # Save in static graph model.
1086
- save_path = os .path .join (output_dir , "inference" , "infer" )
1087
- logger .info ("Exporting inference model to %s" % save_path )
1088
- paddle .jit .save (model , save_path )
1089
- logger .info ("Inference model exported." )
1084
+ model_format = model_format .lower ()
1085
+ if model_format == "paddle" :
1086
+ # Convert to static graph with specific input description
1087
+ model = paddle .jit .to_static (model , input_spec = input_spec )
1088
+
1089
+ # Save in static graph model.
1090
+ save_path = os .path .join (output_dir , "inference" , "infer" )
1091
+ logger .info ("Exporting inference model to %s" % save_path )
1092
+ paddle .jit .save (model , save_path )
1093
+ logger .info ("Inference model exported." )
1094
+ elif model_format == "onnx" :
1095
+ # Export ONNX model.
1096
+ save_path = os .path .join (output_dir , "onnx" , "model" )
1097
+ logger .info ("Exporting ONNX model to %s" % save_path )
1098
+ paddle .onnx .export (model , save_path , input_spec = input_spec )
1099
+ logger .info ("ONNX model exported." )
1100
+ else :
1101
+ logger .info (
1102
+ "This export format is not supported, please select paddle or onnx!"
1103
+ )
1090
1104
1091
1105
def _save_checkpoint (self , model , metrics = None ):
1092
1106
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
0 commit comments