✨[Feature] Possibility to export nn.InstanceNorm2d #3265

Open
albertodallolio opened this issue Oct 28, 2024 · 10 comments
Labels: feature request (New feature or request)

Comments

@albertodallolio

What

I am trying to export a model that contains an nn.InstanceNorm2d operation (the call from the module can be found here), and I get the following error:

RuntimeError: [Error thrown at core/conversion/converters/impl/batch_norm.cpp:199] Expected instance_norm_plugin to be true but got false
Unable to create instance_norm plugin from TensorRT plugin registry%input.5 : Tensor = aten::instance_norm(%input.3, %16, %16, %16, %16, %14, %17, %18, %14), scope: __module.fnet/__module.fnet.norm1 # /opt/conda/lib/python3.11/site-packages/torch/nn/functional.py:2866:0

If I understand correctly, this issue comes from the fact that there is no conversion for this operation. I know that there is one for instance_norm, though. Can you suggest any alternative solutions?

My workspace

Docker image from PyTorch: pytorch/pytorch:2.5.0-cuda12.1-cudnn9-devel
Then run:

pip install tensorrt==10.5.0
pip install torch_tensorrt

Thanks in advance for your help.

@narendasan
Collaborator

Are you able to use the dynamo frontend, @albertodallolio? You can still use TorchScript afterwards if needed, but dynamo has far more comprehensive coverage of the PyTorch opset.

@albertodallolio
Author

Thanks a lot for the fast response.

I think your suggestion moved things forward. I am now getting a different error:

E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0] Error while creating guard:
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0] Name: ''
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     Source: shape_env
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     Create Function: SHAPE_ENV
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     Guard Types: None
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     Code List: None
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     Object Weakref: None
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     Guarded Class Weakref: None
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0] Traceback (most recent call last):
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_guards.py", line 281, in create
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     return self.create_fn(builder, self)
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/guards.py", line 1836, in SHAPE_ENV
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     guards = output_graph.shape_env.produce_guards(
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4178, in produce_guards
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]     raise ConstraintViolationError(
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (_2, _3)! For more information, run with TORCH_LOGS="+dynamic".
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Eq(Mod(((L['image1'].size()[2] - 1)//4) - 1, 2), 0).
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   - Not all values of _3 = L['image1'].size()[3] in the specified range 135 <= _3 <= 3840 satisfy the generated guard Eq(Mod(((L['image1'].size()[3] - 1)//4) - 1, 2), 0).
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Ne(Mod(((L['image1'].size()[2] - 1)//4) - 3, 4), 0).
E1028 17:21:33.478000 228 site-packages/torch/_guards.py:283] [0/0]   - Not all values of _3 = L['image1'].size()[3] in the specified range 135 <= _3 <= 3840 satisfy the generated guard Ne(Mod(((L['image1'].size()[3] - 1)//4) - 3, 4), 0).
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0] Created at:
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 615, in transform
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]     tracer = InstructionTranslator(
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2670, in __init__
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]     output=OutputGraph(
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 317, in __init__
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]     self.init_ambient_guards()
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 463, in init_ambient_guards
E1028 17:21:33.480000 228 site-packages/torch/_guards.py:285] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 560, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1477, in inner
    raise constraint_violation_error
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 796, in _compile_inner
    check_fn = CheckFunctionManager(
               ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/guards.py", line 2261, in __init__
    guard.create(builder)
  File "/opt/conda/lib/python3.11/site-packages/torch/_guards.py", line 281, in create
    return self.create_fn(builder, self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/guards.py", line 1836, in SHAPE_ENV
    guards = output_graph.shape_env.produce_guards(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4178, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (_2, _3)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Eq(Mod(((L['image1'].size()[2] - 1)//4) - 1, 2), 0).
  - Not all values of _3 = L['image1'].size()[3] in the specified range 135 <= _3 <= 3840 satisfy the generated guard Eq(Mod(((L['image1'].size()[3] - 1)//4) - 1, 2), 0).
  - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Ne(Mod(((L['image1'].size()[2] - 1)//4) - 3, 4), 0).
  - Not all values of _3 = L['image1'].size()[3] in the specified range 135 <= _3 <= 3840 satisfy the generated guard Ne(Mod(((L['image1'].size()[3] - 1)//4) - 3, 4), 0).

Suggested fixes:
  __2 = Dim('__2', min=17, max=270)
  __3 = Dim('__3', min=17, max=480)
  _2 = 8*__2 - 1
  _3 = 8*__3 - 1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/my_path/trt_generation/torch_2_trt_fmaps_new.py", line 86, in <module>
    trt_model = trt.compile(
                ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 266, in compile
    exp_program = dynamo_trace(
                  ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_tracer.py", line 83, in trace
    exp_program = export(
                  ^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/__init__.py", line 270, in export
    return _export(
           ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1017, in wrapper
    raise e
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 990, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1880, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1224, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1252, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 576, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.UserError: Constraints violated (_2, _3)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Eq(Mod(((L['image1'].size()[2] - 1)//4) - 1, 2), 0).
  - Not all values of _3 = L['image1'].size()[3] in the specified range 135 <= _3 <= 3840 satisfy the generated guard Eq(Mod(((L['image1'].size()[3] - 1)//4) - 1, 2), 0).
  - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Ne(Mod(((L['image1'].size()[2] - 1)//4) - 3, 4), 0).
  - Not all values of _3 = L['image1'].size()[3] in the specified range 135 <= _3 <= 3840 satisfy the generated guard Ne(Mod(((L['image1'].size()[3] - 1)//4) - 3, 4), 0).

Suggested fixes:
  __2 = Dim('__2', min=17, max=270)
  __3 = Dim('__3', min=17, max=480)
  _2 = 8*__2 - 1
  _3 = 8*__3 - 1

This is the code I am running:

import torch
import torch_tensorrt as trt

# my_model is the user's own nn.Module (not shown in the issue)
model = my_model.cuda().eval()
trt_model = trt.compile(
    model,
    inputs=[
        trt.Input(
            min_shape=[1, 1, 135, 135],    # Minimum input shape
            opt_shape=[1, 1, 135, 135],    # Optimal input shape
            max_shape=[1, 1, 2160, 3840],  # Maximum input shape
            dtype=torch.float32,
        ),
        trt.Input(
            min_shape=[1, 1, 135, 135],    # Minimum input shape
            opt_shape=[1, 1, 135, 135],    # Optimal input shape
            max_shape=[1, 1, 2160, 3840],  # Maximum input shape
            dtype=torch.float32,
        ),
    ],
    ir="dynamo",
    enabled_precisions={torch.float},
)

Any idea what the error might be? I was actually planning to use dynamic input images, so it would be great to make it work with Dynamo. Thanks a lot

@narendasan
Collaborator

@peri044 can you take a look at the dynamic shape issue here? Do you know whether every value in the range has to satisfy the guard, or just the bounds and some intermediate values? My best guess is that this is telling you the lower bound does not fit properly with the shapes further into the network.

  - Not all values of _2 = L['image1'].size()[2] in the specified range 135 <= _2 <= 2160 satisfy the generated guard Eq(Mod(((L['image1'].size()[2] - 1)//4) - 1, 2), 0).

@albertodallolio
Author

It seems that ALL values in the range 135 <= _2 <= 2160 for the shape need to satisfy that equality... Can that be?
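
One way to see what the message is complaining about is to evaluate the guards by hand over the declared range. The sketch below transcribes the two guard expressions from the error message into plain integer arithmetic (this assumes Mod and // behave like Python's % and // for positive sizes, which is the case here). The function names are made up for illustration and are not part of PyTorch.

```python
# Guards copied from the ConstraintViolationError message, rewritten as plain
# integer arithmetic. `size` stands for L['image1'].size()[2] (or size()[3]).
# All function names here are hypothetical helpers, not PyTorch APIs.


def guard_eq(size: int) -> bool:
    # Eq(Mod(((size - 1)//4) - 1, 2), 0)
    return (((size - 1) // 4) - 1) % 2 == 0


def guard_ne(size: int) -> bool:
    # Ne(Mod(((size - 1)//4) - 3, 4), 0)
    return (((size - 1) // 4) - 3) % 4 != 0


def satisfies_guards(size: int) -> bool:
    return guard_eq(size) and guard_ne(size)


def violations(lo: int, hi: int) -> list[int]:
    """Sizes in [lo, hi] that break at least one of the two guards."""
    return [s for s in range(lo, hi + 1) if not satisfies_guards(s)]


if __name__ == "__main__":
    bad = violations(135, 2160)
    print("traced size 135 passes:", satisfies_guards(135))
    print("sizes in range that fail:", len(bad), "of", 2160 - 135 + 1)
    print("first failing sizes:", bad[:5])
```

The size used for tracing (135) passes both guards, but other values inside 135 <= _2 <= 2160 do not, which matches the wording "Not all values ... satisfy the generated guard".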

@peri044
Collaborator

peri044 commented Oct 28, 2024

Dynamo generates guards internally in PyTorch, and these sometimes error out. Could you try exporting with the approach shown in
https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/utils.py#L29-L35, using PyTorch 2.5?

@albertodallolio
Author

albertodallolio commented Oct 28, 2024

Hey @peri044 thanks for your suggestion.

Maybe a stupid question from my side: how should I parametrize the dynamic_shapes argument here? What does the 1 in {1: seq_len} refer to?
Also, I guess inputs_value are just some valid dummy inputs to the model.

ep = torch.export._trace._export(
    model,
    (inputs_value,),
    dynamic_shapes=({1: seq_len},),
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)

Reposting my compile call here for reference:

trt_model = trt.compile(
    model,
    inputs=[
        trt.Input(
            min_shape=[1, 1, 135, 135],    # Minimum input shape
            opt_shape=[1, 1, 135, 135],    # Optimal input shape
            max_shape=[1, 1, 2160, 3840],  # Maximum input shape
            dtype=torch.float32,
        ),
        trt.Input(
            min_shape=[1, 1, 135, 135],    # Minimum input shape
            opt_shape=[1, 1, 135, 135],    # Optimal input shape
            max_shape=[1, 1, 2160, 3840],  # Maximum input shape
            dtype=torch.float32,
        ),
    ],
    ir="dynamo",
    enabled_precisions={torch.float},
)

Thanks a lot in advance again for your help. Highly appreciated 👍

@albertodallolio
Author

Also, I am failing to run torch.export._trace._export in the first place.

Getting:

AttributeError: module 'torch.export' has no attribute '_trace'

My pytorch version is:

import torch
print(torch.__version__)
# 2.5.0+cu121

@peri044
Collaborator

peri044 commented Oct 28, 2024

@albertodallolio The trt.compile call in your code is a two-step process. Under the hood it performs these steps:

  1. call torch.export on the model to export the graph
  2. call dynamo.compile which compiles the graph into TensorRT.

Step 1 is failing in your case.
Can you try importing it this way?

from torch.export._trace import _export

Here's a reference for configuring dynamic shapes : https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes

In your case, I believe it would probably be something like this (check out the references above):

dim_2 = torch.export.Dim("dim_2", min=135, max=2160)
dim_3 = torch.export.Dim("dim_3", min=135, max=3840)
dynamic_shapes = ({2: dim_2, 3: dim_3}, {2: dim_2, 3: dim_3},)
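
For comparison, the "Suggested fixes" section of the export log earlier in this thread proposes restricting each dynamic size to the form 8*k - 1. The sketch below is plain Python, with no torch calls, and only illustrates which concrete sizes that form allows. The helper names derived_size and snap_down are hypothetical, and the check covers only the Eq guard from the error message, not the Ne guards.

```python
# Sizes of the form 8*k - 1, following the "Suggested fixes" in the export log:
#   __2 = Dim('__2', min=17, max=270);  _2 = 8*__2 - 1
#   __3 = Dim('__3', min=17, max=480);  _3 = 8*__3 - 1
# derived_size and snap_down are hypothetical helpers for illustration only.


def derived_size(k: int) -> int:
    """Concrete size for a given value k of the root dimension."""
    return 8 * k - 1


def snap_down(size: int) -> int:
    """Largest value of the form 8*k - 1 that does not exceed `size` (size >= 7)."""
    return 8 * ((size + 1) // 8) - 1


def guard_eq(size: int) -> bool:
    # Eq(Mod(((size - 1)//4) - 1, 2), 0), transcribed from the error message
    return (((size - 1) // 4) - 1) % 2 == 0


if __name__ == "__main__":
    print("height range:", derived_size(17), "to", derived_size(270))
    print("width range: ", derived_size(17), "to", derived_size(480))
    print("2160 snaps to", snap_down(2160), "and 3840 snaps to", snap_down(3840))
    print("Eq guard holds for all suggested sizes:",
          all(guard_eq(derived_size(k)) for k in range(17, 481)))
```

Under this form the maximum sizes become 2159 and 3839 rather than 2160 and 3840, so inputs would need to be cropped or padded to an allowed size. Whether that is acceptable depends on the model.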

@albertodallolio
Author

Thanks for your suggestion. I think we are actually getting closer: the export step went through (with a bunch of warnings), but now the compile step errors out with:

Traceback (most recent call last):
  File "/my_module/trt_generation/torch_2_trt_fmaps_new.py", line 134, in <module>
    trt_model = trt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 288, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 436, in compile_module
    submodule_inputs = partitioning.construct_submodule_inputs(submodule)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/partitioning/common.py", line 119, in construct_submodule_inputs
    raise AssertionError(
AssertionError: Input sum_dim_int_list does not contain metadata. Please ensure you have exported the graph correctly

Looks like it is failing somewhere here. But I printed the placeholder nodes of the exported graph as below and could not find sum_dim_int_list anywhere:

module_inputs = [
    node for node in exported_module.graph.nodes if node.op == "placeholder"
]
for module_input in module_inputs:
    print("=======")
    print(module_input)
    print(module_input.name)
    print(module_input.meta)
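
The loop above only visits placeholder nodes. The traceback shows the check running in construct_submodule_inputs(submodule), that is, on a submodule produced by partitioning, so the name may belong to an intermediate node of the top-level graph rather than to one of its placeholders. A minimal sketch of searching every node regardless of op type follows. The helper names and the stand-in nodes are hypothetical and do not reflect the real graph. With a real export, exported_module.graph.nodes would be passed instead.

```python
from types import SimpleNamespace


def find_node(nodes, name):
    """Return the first node with the given name, whatever its op type, or None."""
    return next((n for n in nodes if n.name == name), None)


def nodes_without_meta(nodes):
    """Return (name, op) for every node whose `meta` dict is empty or missing."""
    return [(n.name, n.op) for n in nodes if not getattr(n, "meta", None)]


# Stand-in nodes for illustration only (hypothetical data, not the real graph).
# With a real export, pass exported_module.graph.nodes instead of demo_nodes.
demo_nodes = [
    SimpleNamespace(name="image1", op="placeholder", meta={"val": "FakeTensor(...)"}),
    SimpleNamespace(name="sum_dim_int_list", op="call_function", meta={}),
    SimpleNamespace(name="output", op="output", meta={"val": "FakeTensor(...)"}),
]

if __name__ == "__main__":
    hit = find_node(demo_nodes, "sum_dim_int_list")
    print("found as:", hit.op if hit else "not found")
    print("nodes with empty meta:", nodes_without_meta(demo_nodes))
```

If the name does show up as a non-placeholder node with empty metadata, that would be useful detail to attach to the bug report.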

My code is now:

    from torch.export._trace import _export
    dim_0 = torch.export.Dim("dim_0", min=1)
    dim_1 = torch.export.Dim("dim_1", min=1)
    dim_2 = torch.export.Dim("dim_2", min=135, max=2160)
    dim_3 = torch.export.Dim("dim_3", min=135, max=3840)
    dynamic_shapes = {'image1': {2: dim_2, 3: dim_3}, 'image2': {2: dim_2, 3: dim_3},}
    exported_module = _export(
        model,
        (images_iter_1[0][0], images_iter_1[0][1]),
        dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )
   
    trt_model = trt.dynamo.compile(
        exported_module,
        inputs=[(images[0][0], images[0][1])], # shapes [1, 1, 135, 135]
        enabled_precisions={torch.float32},
        truncate_double=True,
        device=torch.device("cuda:0"),
        disable_tf32=True,
        use_explicit_typing=True,
        use_fp32_acc=True,
    )

Thanks a lot for your help.

@narendasan
Collaborator

This seems like a torch-tensorrt bug; we can take a look.
