When using the JAX backend with jit enabled, I get a `TracerBoolConversionError` when a custom layer that calls `add_loss` is wrapped in `TimeDistributed` inside a multi-input functional model.

The following minimal example raises the error. Running the code a second time makes the error go away.
```python
import os

# Select the JAX backend before importing Keras
os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np


class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, x):
        # Small L2 penalty on the layer input, registered via add_loss
        self.add_loss(1e-6 * keras.ops.mean(x ** 2))
        return x

    def compute_output_shape(self, input_shape):
        return input_shape


X = keras.layers.Input(shape=(3, 2))
Z = keras.layers.Input(shape=(3, 2))

# One TimeDistributed-wrapped layer shared by both inputs
encoder = keras.layers.TimeDistributed(CustomLayer())
E = encoder(X) + encoder(Z)
Y = keras.layers.Dense(1)(E)

model = keras.Model(
    inputs=[X, Z],
    outputs=Y
)
model.compile(
    loss="mse",
    optimizer="adam",
    jit_compile=True  # the error is reported with jit enabled
)
h = model.fit(
    [np.random.normal(size=(100, 3, 2)), np.random.normal(size=(100, 3, 2))],
    np.random.normal(size=(100, 3, 1))
)
```
I tested this with the following versions:

- keras: 3.8.0, 3.11.2
- jax: 0.5.2, 0.7.1
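For context, JAX raises `TracerBoolConversionError` when Python code calls `bool()` on a traced array inside a jitted function, for example through an `if` statement. The sketch below is independent of Keras and does not reproduce the bug above or identify where in Keras the conversion happens; `flip_with_if` and `flip_with_where` are hypothetical names used only for illustration. It shows the error and the usual jit-compatible alternative, `jnp.where`, which keeps the branch selection inside the traced computation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_with_if(x):
    # Under jit, `jnp.mean(x) > 0` is a tracer; the Python `if`
    # calls bool() on it, which raises TracerBoolConversionError.
    if jnp.mean(x) > 0:
        return x
    return -x


@jax.jit
def flip_with_where(x):
    # jnp.where evaluates both branches and selects between them
    # as part of the traced computation, so no bool() is needed.
    return jnp.where(jnp.mean(x) > 0, x, -x)


raised = None
try:
    flip_with_if(jnp.ones(3))
except jax.errors.TracerBoolConversionError as err:
    raised = type(err).__name__

print(raised)                          # TracerBoolConversionError
print(flip_with_where(-jnp.ones(3)))   # [1. 1. 1.]
```

The `jnp.where` form trades a little extra computation (both branches are evaluated) for being traceable; `jax.lax.cond` is the alternative when only one branch should run.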