Description
🐛 Describe the bug
Trying to export a simple ConvNet for CIFAR-10.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
The export script is here. The error happens during the to_edge() call.
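For context, here is a minimal sketch of the export path being exercised, adapted from the ExecuTorch training XOR example that the file path suggests. The TrainingNet wrapper, the CrossEntropyLoss, and the example input shapes below are assumptions based on that example, not the actual script.

    import torch
    from torch.export import export
    from torch.export.experimental import _export_forward_backward

    from executorch.exir import to_edge


    class TrainingNet(torch.nn.Module):
        # Assumed wrapper (from the XOR example): adds a loss so the joint
        # forward/backward graph can be captured.
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, label):
            pred = self.net(x)
            return self.loss(pred, label), pred.detach().argmax(dim=1)


    def _export_model():
        net = TrainingNet(ConvNet())
        # CIFAR-10-sized example inputs: one 3x32x32 image and one class label.
        x = torch.randn(1, 3, 32, 32)
        label = torch.ones(1, dtype=torch.int64)
        # Capture the forward graph.
        ep = export(net, (x, label), strict=True)
        # Add the backward graph; this is where aten.convolution_backward
        # enters the exported program.
        ep = _export_forward_backward(ep)
        # Lower to the edge dialect -- the step that raises the SpecViolationError.
        ep = to_edge(ep)
        return ep

The to_edge(ep) call in the last step corresponds to the frame at line 44 of export_model_cifar10.py in the traceback below.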
Traceback (most recent call last):
File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 81, in <module>
main()
File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 68, in main
ep = _export_model()
File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 44, in _export_model
ep = to_edge(ep)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 101, in wrapper
return func(self, *args, **kwargs)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1217, in to_edge
edge_programs[name] = _generate_edge_program(name, config, program)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 799, in _generate_edge_program
edge_program = ExportedProgram(
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 916, in __init__
self.validate()
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 1466, in validate
self._validate()
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 1475, in _validate
v().check(self)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 166, in check
self._check_graph_module(ep.graph_module)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 290, in _check_graph_module
self.check_additional(gm)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 284, in check_additional
_check_tensor_args_matching_op_allowed_dtype(gm)
File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 205, in _check_tensor_args_matching_op_allowed_dtype
raise SpecViolationError(
torch._export.verifier.SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes:
Operator: <EdgeOpOverload: aten.convolution_backward.default>: schema = aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) with args: {'grad_output': torch.float32, 'input': torch.float32, 'weight': torch.float32, '__ret_0': torch.float32, '__ret_1': torch.float32}
stack trace: File "<eval_with_key>.21 from /localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1180 in wrapped", line 124, in forward
convolution_backward_1 = torch.ops.aten.convolution_backward.default(where_7, primals_11, primals_1, [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); where_7 = primals_11 = primals_1 = None
Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs.
Versions
Collecting environment information...
PyTorch version: 2.7.0.dev20250131+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-130-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe
Nvidia driver version: 550.120
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7H12 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU max MHz: 2600.0000
CPU min MHz: 1500.0000
BogoMIPS: 5199.82
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (32 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63
NUMA node1 CPU(s): 64-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT disabled
Vulnerability Spec rstack overflow: Mitigation; SMT disabled
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] executorch==0.6.0a0+1fda542
[pip3] numpy==2.0.0
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.7.0.dev20250131+cpu
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250131+cpu
[pip3] torchsr==1.0.4
[pip3] torchtune==0.5.0
[pip3] torchvision==0.22.0.dev20250131+cpu
[pip3] triton==3.1.0
[conda] Could not collect