
'jit_compile=True' raises InternalTorchDynamoError with EfficientNetV2 models on torch backend. #21647

@Doch88

Description


As per the title, fitting a Keras EfficientNetV2 model with jit_compile=True on the torch backend raises an InternalTorchDynamoError; see the output below.

Environment (running on Colab, with GPU):
Keras version: 3.11.3
Torch version: 2.8.0+cu126

Minimal reproducible example:

import os
os.environ["KERAS_BACKEND"] = 'torch'

import numpy as np
import torch
import keras
from keras import layers, models
from keras.applications import EfficientNetV2B2
from keras.optimizers import Adam
from keras.losses import CategoricalCrossentropy

print(f"Backend: {keras.config.backend()}")

num_classes = 10
batch_size = 16
steps_per_epoch = 5
epochs = 2

# Generate random data
data_shape = (224, 224, 3)
x_train = np.random.rand(batch_size * steps_per_epoch, *data_shape).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(batch_size * steps_per_epoch,))
y_train = np.eye(num_classes)[y_train] 

# Backbone without the top, with global average pooling; H/W are left as None on purpose
base_model = EfficientNetV2B2(include_top=False, input_shape=(None, None, 3), pooling='avg', include_preprocessing=True)
x = base_model.output
output = layers.Dense(num_classes, activation='softmax')(x)
model = models.Model(inputs=base_model.input, outputs=output)

# jit_compile=True is the flag that triggers the error
model.compile(optimizer=Adam(learning_rate=0.001), loss=CategoricalCrossentropy(), metrics=['accuracy'], jit_compile=True)

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

Output:

Backend: torch
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b2_notop.h5
35839040/35839040 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
Epoch 1/2
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.is_leaf.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.flatten.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
W0909 14:47:48.027000 291 torch/_inductor/utils.py:1436] [102/0] Not enough SMs to use max_autotune_gemm mode
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8] torch._dynamo hit config.recompile_limit (8)
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8]    function: '__call__' (/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py:816)
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8]    last reason: 3/7: expected type of 'args[0]' to be a tensor type, ' but found <class 'list'>
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
---------------------------------------------------------------------------
InternalTorchDynamoError                  Traceback (most recent call last)
/tmp/ipython-input-2007246960.py in <cell line: 0>()
     32 print("Model compiled!")
     33 
---> 34 model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

23 frames
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py in compute_exception_table(instructions)
    894 
    895     # Sort keys by increasing start, then decreasing end
--> 896     keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1]))
    897     # smallest byte that the next exception table entry can start at
    898     nexti = 0

InternalTorchDynamoError: TypeError: '<' not supported between instances of 'NoneType' and 'int'

from user code:
   File "/usr/local/lib/python3.12/dist-packages/keras/src/trainers/compile_utils.py", line 693, in call
    if not tree.is_nested(y_true) and not tree.is_nested(y_pred):

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Interestingly, when running on CPU (same library versions), the error changes:

Backend: torch
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b2_notop.h5
35839040/35839040 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Epoch 1/2
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.is_leaf.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.flatten.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
W0909 14:55:17.591000 270 torch/utils/cpp_extension.py:118] [93/0] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-360142954.py in <cell line: 0>()
     30 model.compile(optimizer=Adam(learning_rate=0.001), loss=CategoricalCrossentropy(), metrics=['accuracy'], jit_compile=True)
     31 
---> 32 model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

38 frames
/usr/local/lib/python3.12/dist-packages/sympy/core/relational.py in __bool__(self)
    514 
    515     def __bool__(self):
--> 516         raise TypeError("cannot determine truth value of Relational")
    517 
    518     def _eval_as_set(self):

RuntimeError: Exception encountered when calling Conv2D.call().

TypeError: cannot determine truth value of Relational

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


Arguments received by Conv2D.call():
  • inputs=torch.Tensor(shape=torch.Size([16, 112, 112, 32]), dtype=float32)
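
Both tracebacks suggest enabling more verbose Dynamo logging. For reference, this is roughly how that could be turned on for the next run (a minimal sketch based on the hints in the error messages; the environment variables should be set before torch is imported):

import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"          # full internal Dynamo stack trace
os.environ["TORCH_LOGS"] = "+dynamo,recompiles"  # detailed Dynamo logs plus recompile reasons
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras

TORCH_LOGS accepts a comma-separated list, so the two suggestions from the warnings can be combined.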

What I tried, without success (a representative sketch of these variants is shown after the list):

  • A fixed input_shape (instead of None dimensions) on the model.
  • Different EfficientNetV2 variants.
  • Using torch.nn losses instead of Keras losses.
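
For reference, the variations above looked roughly like the following (a representative sketch reusing the names from the MRE, not the exact cells I ran; all of them failed with the same kind of error):

# 1) Fixed input shape instead of (None, None, 3)
base_model = EfficientNetV2B2(include_top=False, input_shape=(224, 224, 3), pooling='avg', include_preprocessing=True)

# 2) A different EfficientNetV2 variant, e.g. B0
from keras.applications import EfficientNetV2B0
base_model = EfficientNetV2B0(include_top=False, input_shape=(224, 224, 3), pooling='avg', include_preprocessing=True)

# 3) A torch.nn loss instead of the Keras CategoricalCrossentropy
model.compile(optimizer=Adam(learning_rate=0.001), loss=torch.nn.CrossEntropyLoss(), metrics=['accuracy'], jit_compile=True)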
