
RuntimeError on torch.unique_consecutive with torch.compile(fullgraph=True) #113118

Closed
zoux1a opened this issue Nov 7, 2023 · 7 comments
Labels: bug, module: dynamic shapes, oncall: pt2, triaged

@zoux1a

zoux1a commented Nov 7, 2023

🐛 Describe the bug

In eager mode, the function behaves as expected. Under torch.compile, however, it raises a DynamicOutputShapeException, which suggests a potential bug.

import torch

def forward(x):
    return torch.unique_consecutive(dim=0, input=x)

x = torch.rand([2], dtype=torch.float32)  # generate the input
forward(x)  # eager mode: works
print("build succeeded")
torch.compile(forward, fullgraph=True)(x)  # torch.compile with fullgraph=True: raises
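The exception is raised in fake_tensor.py, i.e. during the pass that runs the function on fake tensors carrying only metadata (shape, dtype, device) and no values. A plain-Python sketch of what unique_consecutive computes on a 1-D input shows why that pass gives up: the output length depends on the element values, not only on the input shape. The helper unique_consecutive_py is hypothetical and is only an illustration, not PyTorch's kernel.

```python
from typing import List, Sequence, Tuple


def unique_consecutive_py(values: Sequence[float]) -> Tuple[List[float], List[int]]:
    """Hypothetical 1-D helper: collapse runs of equal adjacent elements, return (values, counts)."""
    output: List[float] = []
    counts: List[int] = []
    for v in values:
        if output and output[-1] == v:
            counts[-1] += 1  # same as the previous element: extend the current run
        else:
            output.append(v)  # different value: start a new run
            counts.append(1)
    return output, counts


# Two inputs with the same shape (length 4) ...
a = [1.0, 1.0, 2.0, 2.0]
b = [1.0, 2.0, 3.0, 4.0]
# ... yield outputs of different lengths, so the output size cannot be derived from the shape alone.
print(len(unique_consecutive_py(a)[0]), len(unique_consecutive_py(b)[0]))  # 2 4
```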

error trace:

build succeeded
......
raise DynamicOutputShapeException(func)
torch._subclasses.fake_tensor.DynamicOutputShapeException: aten.unique_consecutive.default

The above exception was the direct cause of the following exception:
......
 raise DynamicOutputShapeException(func)
RuntimeError: Failed running call_function <function boolean_dispatch.<locals>.fn at 0x7f0bf54118b0>(*(), **{'dim': 0, 'input': FakeTensor(..., size=(2,))}):
aten.unique_consecutive.default

During handling of the above exception, another exception occurred:
......
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: dynamic shape operator: aten.unique_consecutive.default
......
build succeeded
Traceback (most recent call last):
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1536, in run_node
    return node.target(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/functional.py", line 1053, in _consecutive_return_output
    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/functional.py", line 970, in _unique_consecutive_impl
    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1390, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1645, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 540, in dyn_shape
    raise DynamicOutputShapeException(func)
torch._subclasses.fake_tensor.DynamicOutputShapeException: aten.unique_consecutive.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1451, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 994, in wrap_fake_exception
    return fn()
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1452, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1557, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1536, in run_node
    return node.target(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/functional.py", line 1053, in _consecutive_return_output
    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/functional.py", line 970, in _unique_consecutive_impl
    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1390, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1645, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 540, in dyn_shape
    raise DynamicOutputShapeException(func)
RuntimeError: Failed running call_function <function boolean_dispatch.<locals>.fn at 0x7f0bf54118b0>(*(), **{'dim': 0, 'input': FakeTensor(..., size=(2,))}):
aten.unique_consecutive.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/guihuan/LLM/results/torch-2/2023-11-03-15-06/repros/repro26.py", line 12, in <module>
    torch.compile(forward, fullgraph=True)(x)# on torch.compile mode(with fullgran=True)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 571, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 408, in _convert_frame_assert
    return _compile(
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 619, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 536, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 149, in _fn
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 504, in transform
    tracer.run()
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2111, in run
    super().run()
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in run
    and self.step()
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 712, in step
    getattr(self, inst.opname)(inst)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 407, in wrapper
    return inner_fn(self, inst)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1200, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 584, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/variables/torch.py", line 712, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1357, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1447, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1468, in get_fake_value
    unimplemented(f"dynamic shape operator: {cause.func}")
  File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/torch/_dynamo/exc.py", line 184, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: dynamic shape operator: aten.unique_consecutive.default

from user code:
   File "/home/guihuan/LLM/results/torch-2/2023-11-03-15-06/repros/repro26.py", line 7, in forward
    return torch.unique_consecutive(dim=0,input=x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

Collecting environment information...
PyTorch version: 2.2.0.dev20231105+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2070
GPU 1: NVIDIA GeForce RTX 2070
GPU 2: NVIDIA GeForce RTX 2070
GPU 3: NVIDIA GeForce RTX 2070

Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU E5-2630 v3 @ 2.40GHz
CPU family: 6
Model: 63
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 2
Stepping: 2
CPU max MHz: 3200.0000
CPU min MHz: 1200.0000
BogoMIPS: 4794.64
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d
Virtualization: VT-x
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 4 MiB (16 instances)
L3 cache: 40 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231105+cu118
[pip3] torchaudio==2.2.0.dev20231105+cu118
[pip3] torchvision==0.17.0.dev20231105+cu118
[conda] cudatoolkit 11.8.0 h6a678d5_0 defaults
[conda] numpy 1.26.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+6e4932cda8 pypi_0 pypi
[conda] torch 2.2.0.dev20231105+cu118 pypi_0 pypi
[conda] torchaudio 2.2.0.dev20231105+cu118 pypi_0 pypi
[conda] torchvision 0.17.0.dev20231105+cu118 pypi_0 pypi

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519

@ezyang
Contributor

ezyang commented Nov 8, 2023

This should be relatively easy to fix: we need to induce a graph break in this case. For most data-dependent operators this already happens, so we'll need to see why it doesn't here.
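A minimal sketch of the graph-break path, assuming a PyTorch version where torch.compile is available: leaving fullgraph at its default (False) lets Dynamo break the graph around an operator it cannot trace and run that operator eagerly, so the compiled function should return the same result as eager mode. On versions that can capture the operator, the result should also match. backend="eager" is an assumption made only so the example does not need a compiler toolchain.

```python
import torch


def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.unique_consecutive(x, dim=0)


x = torch.tensor([1.0, 1.0, 2.0, 3.0, 3.0])

# fullgraph defaults to False, so an untraceable operator causes a graph break instead of an error.
compiled = torch.compile(forward, backend="eager")

eager_out = forward(x)
compiled_out = compiled(x)
print(eager_out)  # tensor([1., 2., 3.])
```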

@wconstab added the bug and triaged labels Nov 13, 2023
@ezyang
Contributor

ezyang commented Nov 20, 2023

Actually, this is not a bug: set the config option torch._dynamo.config.capture_dynamic_output_shape_ops = True.

@a-gardner1
Contributor

a-gardner1 commented May 10, 2024

@ezyang I just encountered a similar issue with torch.unique using torch.onnx.dynamo_export (minimal repro below). The suggested configuration capture_dynamic_output_shape_ops is not an option for torch.onnx.dynamo_export. The closest option that I am aware of is torch.onnx.ExportOptions(dynamic_shapes=True), which does not make any difference in the observed error.

Edit: I discovered that capture_dynamic_output_shape_ops is a variable in torch._dynamo.config. Setting it to True makes no difference. This is in torch==2.3.0.

import torch
import torch.onnx
torch.onnx.dynamo_export(lambda x: torch.unique(x), torch.arange(10))

error trace:

.../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only 
implements opset version 18 for now. If you need to use a different opset version, please register them with 
register_custom_op.
  warnings.warn(
Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1571, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/functional.py", line 991, in _return_output
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/functional.py", line 905, in _unique_impl
    output, inverse_indices, counts = torch._unique2(
                                      ^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1649, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 523, in dyn_shape
    raise DynamicOutputShapeException(func)
torch._subclasses.fake_tensor.DynamicOutputShapeException: aten._unique2.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
    return fn()
           ^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1487, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1592, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1571, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_jit_internal.py", line 499, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/functional.py", line 991, in _return_output
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/functional.py", line 905, in _unique_impl
    output, inverse_indices, counts = torch._unique2(
                                      ^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1649, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 523, in dyn_shape
    raise DynamicOutputShapeException(func)
RuntimeError: Failed running call_function <function boolean_dispatch.<locals>.fn at 0x7f090234c180>(*(FakeTensor(..., size=(10,), dtype=torch.int64),), **{}):
aten._unique2.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1433, in dynamo_export
    ).export()
      ^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1175, in export
    graph_module = self.options.fx_tracer.generate_fx(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 213, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
                                ^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1355, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 168, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs), model=model)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
                       ^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
        ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
    self.call_function(fn, args, kwargs)
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/torch.py", line 542, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1503, in get_fake_value
    unimplemented(f"dynamic shape operator: {cause.func}")
  File ".../lib/python3.11/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: dynamic shape operator: aten._unique2.default

from user code:
   File "...", line 713, in <lambda>
    torch.onnx.dynamo_export(lambda x: torch.unique(x), torch.arange(10))

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File ".../.vscode-server/extensions/ms-python.python-2024.4.1/python_files/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File ".../.vscode-server/extensions/ms-python.python-2024.4.1/python_files/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File ".../.vscode-server/extensions/ms-python.python-2024.4.1/python_files/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File ".../.vscode-server/extensions/ms-python.python-2024.4.1/python_files/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.vscode-server/extensions/ms-python.python-2024.4.1/python_files/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File ".../.vscode-server/extensions/ms-python.python-2024.4.1/python_files/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "...", line 713, in <module>
    torch.onnx.dynamo_export(lambda x: torch.unique(x), torch.arange(10))
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1444, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 
'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in
VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a 
bug on PyTorch Github: https://github.com/pytorch/pytorch/issues

@ezyang
Contributor

ezyang commented May 11, 2024

Support for unique2 was added in #124306; try a nightly.

@ezyang ezyang closed this as completed May 11, 2024
@a-gardner1
Contributor

Using the nightly now gives a different error (so progress?):

Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1503, in dynamo_export
    ).export()
      ^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1236, in export
    graph_module = self.options.fx_tracer.generate_fx(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 232, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 242, in pre_export_passes
    return exporter.common_pre_export_passes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1559, in common_pre_export_passes
    ).analyze(infra.levels.ERROR)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 85, in analyze
    self._lint(analysis_result, diagnostic_level)
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 37, in _lint
    self.diagnostic_context.log_and_raise_if_error(diagnostic)
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 367, in log_and_raise_if_error
    raise RuntimeErrorWithDiagnostic(diagnostic)
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten._unique2.default']}. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "...", line 714, in <module>
    torch.onnx.dynamo_export(lambda x: torch.unique(x), torch.arange(10))
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1514, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'.
SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension,
or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: 
https://github.com/pytorch/pytorch/issues

@a-gardner1
Contributor

@ezyang I have locally fixed this in onnxscript and successfully exported torch.unique to an ONNX program. Should I create a PR?

@ezyang
Contributor

ezyang commented May 14, 2024

Yes please!

pytorchmergebot pushed a commit that referenced this issue Jun 1, 2024
Follow-up to #113118 and #124306.

Developed in coordination with the solution to microsoft/onnxscript#1547

This PR adds the missing fake tensor implementation for `aten.unique_dim`, thus enabling tracing and compilation of `torch.unique` when `dim` is not None.

Local testing has proceeded with the following simple script (provided that one has checked out the changes in microsoft/onnxscript#1547):

```python
import logging

import numpy as np
import onnx
import onnxruntime as ort
import torch

# Export torch.unique along dim 0 (also returning inverse indices) with dynamic shapes.
onnx_program = torch.onnx.dynamo_export(
    lambda x: torch.unique(x, dim=0, return_inverse=True),
    torch.arange(10),
    export_options=torch.onnx.ExportOptions(
        dynamic_shapes=True,
        diagnostic_options=torch.onnx.DiagnosticOptions(
            verbosity_level=logging.DEBUG)))
onnx_program.save("torch_unique.onnx")

# Run the exported program through its own input adapter.
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
onnx_outputs = onnx_program(*onnx_inputs)

# Validate the saved model, then run it with ONNX Runtime on random integer inputs.
loaded_onnx_program = onnx.load("torch_unique.onnx")
onnx.checker.check_model(loaded_onnx_program)
ort_session = ort.InferenceSession("torch_unique.onnx")
inputs = np.random.randint(0, 10, 10)
print(f"Inputs: {inputs}")
outputs = ort_session.run(None, {"l_x_": inputs})
print(f"Outputs: {outputs}")
print("Success")
```

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #126561
Approved by: https://github.com/ezyang
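This commit (like #124306 before it for aten._unique2) supplies a fake tensor implementation: a function that, instead of computing values, reports output shapes and uses a fresh symbolic (unbacked) size wherever the size depends on the data. The plain-Python sketch below models only that shape logic for the three results unpacked as output, inverse_indices, counts in the tracebacks above. It is not PyTorch's code; the helper unique_fake_shapes is hypothetical, and returning a (0,) placeholder for results that were not requested is an assumption.

```python
from typing import Optional, Tuple, Union

# A size is either a concrete int or a symbolic name standing in for a data-dependent size.
Size = Union[int, str]
Shape = Tuple[Size, ...]


def unique_fake_shapes(
    input_shape: Tuple[int, ...],
    dim: Optional[int] = None,
    return_inverse: bool = False,
    return_counts: bool = False,
    symbol: str = "u0",
) -> Tuple[Shape, Shape, Shape]:
    """Hypothetical helper: shapes of (output, inverse_indices, counts).

    `symbol` stands for the number of unique entries, which is only known
    after the data has been inspected.
    """
    if dim is None:
        # _unique2 case: the input is flattened, so the output is 1-D.
        output: Shape = (symbol,)
        inverse: Shape = input_shape if return_inverse else (0,)
    else:
        # unique_dim case: only the size along `dim` becomes data-dependent.
        dim = dim % len(input_shape)
        output = input_shape[:dim] + (symbol,) + input_shape[dim + 1:]
        inverse = (input_shape[dim],) if return_inverse else (0,)
    # counts has one entry per unique element (assumed (0,) placeholder if not requested).
    counts: Shape = (symbol,) if return_counts else (0,)
    return output, inverse, counts


print(unique_fake_shapes((4, 5), dim=0, return_inverse=True))  # (('u0', 5), (4,), (0,))
```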
petrex pushed a commit to petrex/pytorch that referenced this issue Jun 5, 2024