Commit ea69d95
Merge pull request #155 from Visual-Behavior/trt_export_from_onnx
Trt export from onnx
thibo73800 authored Mar 18, 2022
2 parents 567a376 + 27214bd commit ea69d95
Showing 2 changed files with 109 additions and 13 deletions.
88 changes: 75 additions & 13 deletions alonet/torch2trt/base_exporter.py
@@ -12,7 +12,7 @@
    import onnx
    import onnx_graphsurgeon as gs
    import tensorrt as trt

    import pycuda.driver as cuda
    prod_package_error = None
except Exception as err:
    # Keep the error around: names bound with "except ... as" are cleared when the
    # block exits, so binding it to prod_package_error directly would leave the
    # name unbound for the later "if prod_package_error is not None" checks.
    prod_package_error = err
@@ -21,6 +21,8 @@
from contextlib import redirect_stdout, ExitStack
from alonet.torch2trt.onnx_hack import scope_name_workaround, get_scope_names, rename_tensors_
from alonet.torch2trt import TRTEngineBuilder, TRTExecutor, utils
from alonet.torch2trt.utils import get_nodes_by_op, rename_nodes_



class BaseTRTExporter:
@@ -51,6 +53,7 @@ def __init__(
        operator_export_type=None,
        dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = None,
        opt_profiles: Dict[str, Tuple[List[int]]] = None,
        skip_adapt_graph=False,
        **kwargs,
    ):
        """
@@ -108,6 +111,7 @@ def __init__(
        self.custom_opset = None  # to be redefined in child class if needed
        self.use_scope_names = use_scope_names
        self.operator_export_type = operator_export_type
        self.skip_adapt_graph = skip_adapt_graph
        if dynamic_axes is not None:
            assert opt_profiles is not None, "If dynamic_axes are to be used, opt_profiles must be provided"
            assert isinstance(dynamic_axes, dict)
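For context, a minimal sketch of what these two arguments could look like together; the input name and shapes below are illustrative, not taken from this PR:

# Hypothetical configuration: one image input with a dynamic batch dimension.
# dynamic_axes maps input names to their dynamic dimensions; opt_profiles gives
# TensorRT the (min, opt, max) shapes to build an optimization profile from.
dynamic_axes = {"img": {0: "batch"}}
opt_profiles = {"img": ([1, 3, 640, 640], [4, 3, 640, 640], [8, 3, 640, 640])}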
@@ -117,13 +121,19 @@
        onnx_dir = os.path.split(onnx_path)[0]
        onnx_file_name = os.path.split(onnx_path)[1]
        model_name = onnx_file_name.split(".")[0]

        if not self.skip_adapt_graph:
            self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
        else:
            self.adapted_onnx_path = os.path.join(onnx_dir, onnx_file_name)

        self.engine_path = os.path.join(onnx_dir, model_name + f"_{precision.lower()}.engine")

        if self.verbose:
            trt_logger = trt.Logger(trt.Logger.VERBOSE)
        else:
            trt_logger = trt.Logger(trt.Logger.WARNING)

        self.engine_builder = TRTEngineBuilder(self.adapted_onnx_path, logger=trt_logger, opt_profiles=opt_profiles)

        if precision.lower() == "fp32":
Expand All @@ -147,15 +157,59 @@ def build_torch_model(self):
        raise Exception("Child class should implement this method")


    def adapt_graph(self, graph):
        """Modify the ONNX graph to ensure compatibility between ONNX and TensorRT

        Returns
        -------
        graph: onnx_graphsurgeon.Graph
        """
        return graph

    def _adapt_graph(self, graph):
        """Modify the ONNX graph to ensure compatibility between ONNX and TensorRT

        Returns
        -------
        graph: onnx_graphsurgeon.Graph
        """
        # TensorRT expects the min/max bounds of a Clip node to be constant inputs:
        # fold them out of the upstream constant nodes into gs.Constant tensors.
        clip_nodes = get_nodes_by_op("Clip", graph)

        def handle_op_Clip(node: gs.Node):
            max_constant = np.array(np.finfo(np.float32).max, dtype=np.float32)
            if "value" in node.inputs[1].i().inputs[0].attrs:
                min_constant = node.inputs[1].i().inputs[0].attrs["value"].values.astype(np.float32)
                if len(node.inputs[2].inputs) > 0:
                    max_constant = node.inputs[2].i().inputs[0].attrs["value"].values.astype(np.float32)
            elif "to" in node.inputs[1].i().inputs[0].attrs:
                min_constant = np.array(np.finfo(np.float32).min, dtype=np.float32)
            else:
                raise Exception("Clip node: could not resolve min/max constants")
            node.inputs.pop(1)
            node.inputs.insert(1, gs.Constant(name=node.name + "_min", values=min_constant))
            node.inputs.pop(2)
            node.inputs.insert(2, gs.Constant(name=node.name + "_max", values=max_constant))

        for n in clip_nodes:
            handle_op_Clip(n)

        # onnx-simplifier is an optional dependency, hence the lazy import
        from onnxsim import simplify

        model = onnx.load(self.onnx_path)
        model_simp, check = simplify(model)

        if check:
            print("\n[INFO] Simplified ONNX model validated. Graph optimized...")
            graph = gs.import_onnx(model_simp)
            graph.toposort()
            graph.cleanup()
        else:
            print("\n[INFO] ONNX model was not validated; keeping the unsimplified graph.")

        # Call the child class for model-specific graph adaptation
        graph = self.adapt_graph(graph)
        return graph
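The helper get_nodes_by_op used above is imported from alonet.torch2trt.utils, but its body is not part of this diff. A plausible minimal implementation, assuming it simply filters the graph by op type, would be:

def get_nodes_by_op(op_name, graph):
    # Collect every node in the graph whose op type matches op_name.
    return [node for node in graph.nodes if node.op == op_name]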

    def prepare_sample_inputs(self) -> Tuple[Tuple[torch.Tensor], Dict[str, Union[torch.Tensor, None]]]:
        """
@@ -247,6 +301,7 @@ def _torch2onnx(self):
            number2scope = get_scope_names(onnx_export_log, strict=False)
            graph = gs.import_onnx(onnx.load(self.onnx_path))
            graph = rename_tensors_(graph, number2scope, verbose=True)
            graph = rename_nodes_(graph, True)
            onnx.save(gs.export_onnx(graph), self.onnx_path)

        print("Saved ONNX at:", self.onnx_path)
@@ -265,15 +320,15 @@ def _onnx2engine(self, **kwargs):
        if prod_package_error is not None:
            raise prod_package_error

        if not self.skip_adapt_graph:
            graph = gs.import_onnx(onnx.load(self.onnx_path))
            graph.toposort()

            # === Modify ONNX graph for TensorRT compatibility
            graph = self._adapt_graph(graph, **kwargs)
            utils.print_graph_io(graph)
            # === Export adapted ONNX for the TRT engine
            onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)

        # === Build engine
        self.engine_builder.export_engine(self.engine_path)
@@ -286,7 +341,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
        threshold = 1e-1
        check = True
        # Get engine info
        model = TRTExecutor(engine, stream=cuda.Stream())
        model.print_bindings_info()
        # Prepare engine inputs
        for i in range(len(sample_inputs)):
@@ -302,6 +357,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
        m_outputs = model.execute()
        print("==== Absolute / relative error:")
        for out in m_outputs:
            print("out", m_outputs[out])
            diff = m_outputs[out].astype(float) - sample_outputs[out].astype(float)
            abs_err = np.abs(diff)
            rel_err = np.abs(diff / (sample_outputs[out] + 1e-6))  # Avoid division by zero
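The thresholding that turns these errors into a pass/fail flag is cut off by this hunk; given threshold = 1e-1 and check = True above, the usual pattern would be something along these lines (a sketch continuing the loop body, not the verbatim source):

            # Fail the sanity check if this output drifts past the threshold
            if np.max(abs_err) > threshold and np.max(rel_err) > threshold:
                check = False
            print(f"{out}: max abs err {np.max(abs_err):.4f}, max rel err {np.max(rel_err):.4f}")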
@@ -332,7 +388,13 @@ def add_argparse_args(parent_parser):
            default=None,
            help="/path/onnx/will/be/exported, by default set as ~/.aloception/weights/MODEL/MODEL.onnx",
        )
        parser.add_argument("--skip_adapt_graph", action="store_true", help="Skip the graph adaptation step")
        parser.add_argument("--batch_size", type=int, default=1, help="Engine batch size, default = 1")
        parser.add_argument("--precision", type=str, default="fp32", help="fp32/fp16/mix, default fp32")
        parser.add_argument("--verbose", action="store_true", help="Helpful when debugging")
        parser.add_argument(
            "--use_scope_names",
            action="store_true",
            help="Save scope names in the ONNX graph, to get named profiles at inference, by default %(default)s",
        )
        return parent_parser
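A caller would typically consume these flags as follows; the parser wiring is an assumption (add_argparse_args is presumed static, as its signature suggests), and only the flag names come from this diff:

import argparse

parser = argparse.ArgumentParser()
parser = BaseTRTExporter.add_argparse_args(parser)
args = parser.parse_args(["--precision", "fp16", "--skip_adapt_graph"])
# args.precision == "fp16", args.skip_adapt_graph is True, args.batch_size == 1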
34 changes: 34 additions & 0 deletions alonet/torch2trt/utils.py
@@ -368,3 +368,37 @@ def execute_sync(context, bindings, inputs, outputs):
    for out in outputs:
        out.host = out.host.reshape(out.shape)
    return [out.host for out in outputs]



def rename_nodes_(graph, verbose=False):
    """Rename graph nodes after their output tensors so they carry readable names in profiling."""
    dont_rename = [v.name for v in graph.inputs + graph.outputs]

    for node in graph.nodes:
        if node.name not in dont_rename:
            # Replace the node name by its output tensor name to include it in profiling
            node.name = node.outputs[0].name
            # If the name is purely numeric, try to derive a readable one
            # from the node's named input tensors
            try:
                id_node = int(node.name)
                node_is_int = True
            except ValueError:
                node_is_int = False

            if node_is_int:
                for inode in node.inputs:
                    try:  # Only consider named inputs (skip purely numeric tensor names)
                        int(inode.name)
                        inode_is_int = True
                    except ValueError:
                        inode_is_int = False

                    # Input is named: build the node name from it
                    if not inode_is_int:
                        new_name = inode.name + "_" + str(id_node)
                        if verbose:
                            print(f" changed {node.name} to {new_name}")
                        node.name = new_name

    return graph
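A minimal usage sketch for the new helper (file paths are illustrative):

import onnx
import onnx_graphsurgeon as gs

from alonet.torch2trt.utils import rename_nodes_

graph = gs.import_onnx(onnx.load("model.onnx"))
graph = rename_nodes_(graph, verbose=True)  # give nodes profiling-friendly names
onnx.save(gs.export_onnx(graph), "model_renamed.onnx")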
