Skip to content

Commit

Permalink
TorchOnnxExportJob: Explicit input_names and output_names (#510)
Browse files Browse the repository at this point in the history
rwth-i6/returnn#1517

Co-authored-by: Albert Zeyer <zeyer@cs.rwth-aachen.de>
  • Loading branch information
kuacakuaca and albertz authored May 23, 2024
1 parent a4aa4a4 commit d782ce3
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion returnn/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import shutil
import subprocess as sp
from typing import Optional
from typing import Optional, Sequence

import i6_core.util as util

Expand Down Expand Up @@ -214,11 +214,15 @@ class TorchOnnxExportJob(Job):
Currently only supports PyTorch via tools/torch_export_to_onnx.py
"""

__sis_hash_exclude__ = {"input_names": None, "output_names": None}

def __init__(
self,
*,
returnn_config: ReturnnConfig,
checkpoint: PtCheckpoint,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
device: str = "cpu",
returnn_python_exe: Optional[tk.Path] = None,
returnn_root: Optional[tk.Path] = None,
Expand All @@ -227,13 +231,32 @@ def __init__(
:param returnn_config: RETURNN config object
:param checkpoint: Path to the checkpoint for export
:param input_names: sequence of model input names.
If not specified, will automatically determine from `extern_data` when available in `returnn_config.config`.
:param output_names: sequence of model output names.
If not specified, will automatically determine from `model_outputs` when available in `returnn_config.config`.
:param device: target device for graph creation
:param returnn_python_exe: file path to the executable for running returnn (python binary or .sh)
:param returnn_root: file path to the RETURNN repository root folder
"""

self.returnn_config = returnn_config
self.checkpoint = checkpoint

# Get the list here, because ReturnnConfig serialization might potentially reorder via `sort_config=True`.
input_names = (
list(returnn_config.config["extern_data"].keys())
if ("extern_data" in returnn_config.config and input_names is None)
else input_names
)
output_names = (
list(returnn_config.config["model_outputs"].keys())
if ("model_outputs" in returnn_config.config and output_names is None)
else output_names
)
self.input_names = input_names
self.output_names = output_names

self.device = device
self.returnn_python_exe = util.get_returnn_python_exe(returnn_python_exe)
self.returnn_root = util.get_returnn_root(returnn_root)
Expand All @@ -260,6 +283,10 @@ def run(self):
"--verbosity",
"5",
]
if self.input_names:
cmd += ["--input_names", ",".join(self.input_names)]
if self.output_names:
cmd += ["--output_names", ",".join(self.output_names)]

util.create_executable("compile.sh", cmd) # convenience file for manual execution
sp.run(cmd, check=True)

0 comments on commit d782ce3

Please sign in to comment.