
Simple ConvNet causes mismatched dtypes during to_edge() call #8206

Open
@holgerroth

Description


🐛 Describe the bug

I am trying to export a simple ConvNet for CIFAR-10 for training with ExecuTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
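As a quick sanity check on the model definition, the flattened size `16 * 5 * 5` used by `fc1` follows from a 3x32x32 CIFAR-10 input. The sketch below is plain Python (no torch needed). It assumes the standard `Conv2d`/`MaxPool2d` output-size formula with default padding and dilation and floor rounding, and the helper name `conv_out` is hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dim for Conv2d / MaxPool2d (floor rounding)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

size = 32                    # CIFAR-10 images are 3x32x32
size = conv_out(size, 5)     # conv1: 5x5 kernel -> 28
size = conv_out(size, 2, 2)  # pool: 2x2, stride 2 -> 14
size = conv_out(size, 5)     # conv2: 5x5 kernel -> 10
size = conv_out(size, 2, 2)  # pool: 2x2, stride 2 -> 5
flat_features = 16 * size * size  # 16 channels out of conv2
print(size, flat_features)   # 5 400
```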

The export script is here. The error is raised during the to_edge() call:

Traceback (most recent call last):
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 81, in <module>
    main()
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 68, in main
    ep = _export_model()
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 44, in _export_model
    ep = to_edge(ep)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 101, in wrapper
    return func(self, *args, **kwargs)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1217, in to_edge
    edge_programs[name] = _generate_edge_program(name, config, program)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 799, in _generate_edge_program
    edge_program = ExportedProgram(
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 916, in __init__
    self.validate()
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 1466, in validate
    self._validate()
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 1475, in _validate
    v().check(self)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 166, in check
    self._check_graph_module(ep.graph_module)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 290, in _check_graph_module
    self.check_additional(gm)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 284, in check_additional
    _check_tensor_args_matching_op_allowed_dtype(gm)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 205, in _check_tensor_args_matching_op_allowed_dtype
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes:

Operator: <EdgeOpOverload: aten.convolution_backward.default>: schema = aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) with args: {'grad_output': torch.float32, 'input': torch.float32, 'weight': torch.float32, '__ret_0': torch.float32, '__ret_1': torch.float32}
stack trace: File "<eval_with_key>.21 from /localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1180 in wrapped", line 124, in forward
    convolution_backward_1 = torch.ops.aten.convolution_backward.default(where_7, primals_11, primals_1, [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]);  where_7 = primals_11 = primals_1 = None
Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs.
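For reference, the failing call can be decoded from the error text alone. `aten::convolution_backward` returns `(grad_input, grad_weight, grad_bias)`, and `output_mask` selects which of them are computed. The sketch below uses hypothetical helpers (`requested_gradients`, `distinct_dtypes`); it is not ExecuTorch's verifier and only restates the values reported in the message, with dtypes kept as plain strings.

```python
# Return order of aten::convolution_backward.
CONV_BACKWARD_OUTPUTS = ("grad_input", "grad_weight", "grad_bias")

def requested_gradients(output_mask):
    """Names of the gradients computed for a given output_mask."""
    return [name for name, wanted in zip(CONV_BACKWARD_OUTPUTS, output_mask) if wanted]

def distinct_dtypes(arg_dtypes):
    """Set of distinct dtypes among the reported tensor arguments and returns."""
    return set(arg_dtypes.values())

# Values copied from the error message (dtypes as strings).
output_mask = [False, True, True]
reported = {
    "grad_output": "torch.float32",
    "input": "torch.float32",
    "weight": "torch.float32",
    "__ret_0": "torch.float32",
    "__ret_1": "torch.float32",
}

print(requested_gradients(output_mask))               # ['grad_weight', 'grad_bias']
print(distinct_dtypes(reported))                      # {'torch.float32'}
print(sum(k.startswith("__ret_") for k in reported))  # 2
```

With these values every reported dtype is `torch.float32`, and only two return values (`__ret_0`, `__ret_1`) are listed, matching the two `True` entries in `output_mask`. The call belongs to `conv1` (`bias_sizes` is `[6]`), so `grad_input` is presumably not requested because the input image does not require a gradient.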

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250131+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-130-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe

Nvidia driver version: 550.120
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7H12 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU max MHz: 2600.0000
CPU min MHz: 1500.0000
BogoMIPS: 5199.82
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (32 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63
NUMA node1 CPU(s): 64-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT disabled
Vulnerability Spec rstack overflow: Mitigation; SMT disabled
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+1fda542
[pip3] numpy==2.0.0
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.7.0.dev20250131+cpu
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250131+cpu
[pip3] torchsr==1.0.4
[pip3] torchtune==0.5.0
[pip3] torchvision==0.22.0.dev20250131+cpu
[pip3] triton==3.1.0
[conda] Could not collect

cc @JacobSzwejbka

Metadata

Assignees: No one assigned

Labels: module: training (Issues related to training models on edge devices), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)