Compatibility issue between add_loss and TimeDistributed when using JAX backend #21605

@AmedeoBiolatti

Description

When using the JAX backend with jit compilation enabled, I get a TracerBoolConversionError when a custom layer that calls add_loss is wrapped in TimeDistributed inside a multi-input functional model.
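For context, TracerBoolConversionError is the error JAX raises whenever a traced (abstract) value is coerced to a Python bool inside a jit-compiled function. A minimal sketch independent of Keras (the function f below is purely illustrative, not taken from the Keras code path):

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Inside jit, `x` is an abstract tracer. The `if x > 0` check forces
    # it to a concrete Python bool, which JAX cannot do during tracing,
    # so it raises TracerBoolConversionError.
    if x > 0:
        return x
    return -x


try:
    f(jnp.array(1.0))
    caught = False
except jax.errors.TracerBoolConversionError:
    caught = True

print("TracerBoolConversionError raised:", caught)
```

Presumably something along the Keras add_loss/TimeDistributed code path performs a similar bool conversion on a traced value during the first compiled train step.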

Below is a minimal example that raises the error. Curiously, running the script a second time makes the error go away.

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np


class CustomLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, x):
        self.add_loss(1e-6 * keras.ops.mean(x ** 2))
        return x

    def compute_output_shape(self, input_shape):
        return input_shape


X = keras.layers.Input(shape=(3, 2))
Z = keras.layers.Input(shape=(3, 2))

encoder = keras.layers.TimeDistributed(CustomLayer())
E = encoder(X) + encoder(Z)
Y = keras.layers.Dense(1)(E)

model = keras.Model(
    inputs=[X, Z],
    outputs=Y
)

model.compile(
    loss="mse",
    optimizer="adam",
    jit_compile=True
)
h = model.fit(
    [np.random.normal(size=(100, 3, 2)), np.random.normal(size=(100, 3, 2))],
    np.random.normal(size=(100, 3, 1))
)

I tested with the following configurations:
keras: 3.8.0, 3.11.2
jax: 0.5.2, 0.7.1
