
Commit caa386f

Add ONNX Export (#2061)
* Add ONNX Export
* deal with comments
* deal with comments
* deal with comments
1 parent fa8fe0e · commit caa386f

File tree

2 files changed (+25, -10 lines)


paddlenlp/trainer/trainer_base.py

Lines changed: 24 additions & 10 deletions
@@ -1042,14 +1042,16 @@ def save_model(self, output_dir: Optional[str]=None):
     def export_model(self,
                      input_spec=None,
                      load_best_model=False,
-                     output_dir: Optional[str]=None):
-        """ Export paddle inference model.
+                     output_dir: Optional[str]=None,
+                     model_format: Optional[str]="paddle"):
+        """ Export paddle inference model or onnx model.
 
         Args:
             input_spec (paddle.static.InputSpec, optional): InputSpec describes the signature information of the model input,
                 such as shape , dtype , name. Defaults to None.
             load_best_model (bool, optional): Load best model. Defaults to False.
             output_dir (Optional[str], optional): Output dir to save the exported model. Defaults to None.
+            model_format (Optional[str], optional): Export model format. There are two options: paddle or onnx, defaults to paddle.
         """
 
         if output_dir is None:
@@ -1079,14 +1081,26 @@ def export_model(self,
         model = unwrap_model(self.model)
         model.eval()
 
-        # Convert to static graph with specific input description
-        model = paddle.jit.to_static(model, input_spec=input_spec)
-
-        # Save in static graph model.
-        save_path = os.path.join(output_dir, "inference", "infer")
-        logger.info("Exporting inference model to %s" % save_path)
-        paddle.jit.save(model, save_path)
-        logger.info("Inference model exported.")
+        model_format = model_format.lower()
+        if model_format == "paddle":
+            # Convert to static graph with specific input description
+            model = paddle.jit.to_static(model, input_spec=input_spec)
+
+            # Save in static graph model.
+            save_path = os.path.join(output_dir, "inference", "infer")
+            logger.info("Exporting inference model to %s" % save_path)
+            paddle.jit.save(model, save_path)
+            logger.info("Inference model exported.")
+        elif model_format == "onnx":
+            # Export ONNX model.
+            save_path = os.path.join(output_dir, "onnx", "model")
+            logger.info("Exporting ONNX model to %s" % save_path)
+            paddle.onnx.export(model, save_path, input_spec=input_spec)
+            logger.info("ONNX model exported.")
+        else:
+            logger.info(
+                "This export format is not supported, please select paddle or onnx!"
+            )
 
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
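
For context, a minimal usage sketch of the new argument (not part of the commit): it assumes an already-constructed Trainer instance named `trainer`, and the InputSpec shapes, dtypes, and names below are illustrative placeholders for a typical transformer input.

from paddle.static import InputSpec

# Placeholder input signature; real shapes/dtypes/names depend on the model.
input_spec = [
    InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
    InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
]

# Existing behaviour (the default): saves a static-graph model under
# <output_dir>/inference/infer.*
trainer.export_model(input_spec=input_spec, model_format="paddle")

# New in this commit: saves <output_dir>/onnx/model.onnx via paddle.onnx.export.
trainer.export_model(input_spec=input_spec, model_format="onnx")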

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ datasets
 tqdm
 paddlefsl
 sentencepiece
+paddle2onnx
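
The new dependency is needed because paddle.onnx.export relies on the paddle2onnx package under the hood. A standalone sketch of that call, using a toy paddle.nn.Linear layer purely for illustration:

import paddle
from paddle.static import InputSpec

# Toy layer purely for illustration.
layer = paddle.nn.Linear(8, 2)
layer.eval()

# Requires the paddle2onnx package added above; writes onnx/model.onnx.
paddle.onnx.export(
    layer,
    "onnx/model",
    input_spec=[InputSpec(shape=[None, 8], dtype="float32", name="x")])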
