
Issue loading model containing Dense layer with Identity initializer #20483

Closed
@kristoferm94


Hi, I've been migrating some model code from Keras 2 to Keras 3, and I think I stumbled upon a bug.

I've noticed that if I save a model containing a Dense layer with an identity kernel initializer in Keras 3, I cannot reload it. Loading raises an exception saying that Keras cannot interpret the initializer identifier stored in config.json inside the .keras model file.
An excerpt from the exception (the full traceback is in the attached test output):

Exception encountered: Could not interpret initializer identifier: {'module': 'keras.initializers', 'class_name': 'IdentityInitializer', 'config': {}, 'registered_name': None}
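Because a .keras file is a zip archive, the offending entry can be checked without loading the model. Below is a minimal stdlib-only sketch that collects every class_name referenced in config.json. The helper names iter_class_names and class_names_in_keras_file are hypothetical and are not Keras APIs.

```python
import json
import zipfile
from collections.abc import Iterator
from os import PathLike
from typing import Any, Union


def iter_class_names(node: Any) -> Iterator[str]:
    """Recursively yield every "class_name" value in a parsed Keras config tree."""
    if isinstance(node, dict):
        name = node.get("class_name")
        if isinstance(name, str):
            yield name
        for value in node.values():
            yield from iter_class_names(value)
    elif isinstance(node, list):
        for item in node:
            yield from iter_class_names(item)


def class_names_in_keras_file(model_path: Union[str, PathLike]) -> list[str]:
    """Return the sorted, de-duplicated class names found in config.json of a .keras archive."""
    with zipfile.ZipFile(model_path, "r") as archive:
        config = json.loads(archive.read("config.json"))
    return sorted(set(iter_class_names(config)))
```

For a model saved as in the tests below, "IdentityInitializer" would be expected to show up in class_names_in_keras_file(save_path), matching the class_name in the exception.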

I have tried this on Windows with Keras 3.6 + JAX 0.4.35 and on Ubuntu with Keras 3.6 + TensorFlow 2.18.0, and I get the same exception on both.

Currently, I am using a hacky workaround: I replace the faulty 'IdentityInitializer' class name in config.json inside the Keras model file with 'Identity'.

Here are pytest tests that replicate this issue. test_save_read_dense_layer_model_with_identity_initializer reproduces the exception; the attached testoutput.txt file contains the test output, including the full exception traceback:

from pathlib import Path
from tempfile import TemporaryDirectory
from zipfile import ZipFile

import keras


def load_model_with_workaround(model_path: Path) -> keras.Model:
    # Copy every entry of the .keras zip archive into a new temporary archive,
    # patching the initializer class name in config.json along the way.
    with (
        TemporaryDirectory() as tmp_dir,
        ZipFile(model_path, "r") as original_model_file,
    ):
        new_model_path = Path(tmp_dir) / "new.keras"
        with ZipFile(new_model_path, "w") as new_model_file:
            for file_name in original_model_file.namelist():
                original_data = original_model_file.read(file_name)

                if file_name == "config.json":
                    # Saving writes "IdentityInitializer", which loading fails
                    # to interpret; "Identity" is accepted.
                    original_data = (
                        original_data.decode("utf-8")
                        .replace(
                            'class_name": "IdentityInitializer"',
                            'class_name": "Identity"',
                        )
                        .encode("utf-8")
                    )

                with new_model_file.open(file_name, "w") as f:
                    f.write(original_data)

        # Load inside the with-block so the temporary file still exists.
        return keras.models.load_model(new_model_path)


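# This test fails: load_model raises "Could not interpret initializer identifier"
def test_save_read_dense_layer_model_with_identity_initializer() -> None:
    model = keras.Sequential(
        [
            keras.layers.Input((5,)),
            keras.layers.Dense(5, kernel_initializer=keras.initializers.Identity()),
        ]
    )
    with TemporaryDirectory() as tmp_dir:
        save_path = Path(tmp_dir) / "mymodel.keras"
        model.save(save_path)
        model_from_file = keras.models.load_model(save_path)
        # The reloaded model should carry the same weights as the original.
        assert [w.tolist() for w in model_from_file.get_weights()] == [
            w.tolist() for w in model.get_weights()
        ]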


# This test passes
def test_save_read_dense_layer_model_with_identity_initializer_using_workaround() -> None:
    model = keras.Sequential(
        [
            keras.layers.Input((5,)),
            keras.layers.Dense(5, kernel_initializer=keras.initializers.Identity()),
        ]
    )
    with TemporaryDirectory() as tmp_dir:
        save_path = Path(tmp_dir) / "mymodel.keras"
        model.save(save_path)
        model_from_file = load_model_with_workaround(save_path)
        # The reloaded model should carry the same weights as the original.
        assert [w.tolist() for w in model_from_file.get_weights()] == [
            w.tolist() for w in model.get_weights()
        ]
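The workaround above patches config.json with a plain string replacement, which relies on the exact '"class_name": "IdentityInitializer"' spacing that Keras happens to write. Below is a sketch of a JSON-aware alternative for that one step: it parses the config, renames the class wherever a dict carries that class_name, and re-serializes it. The helper names rename_class and patch_config_json are hypothetical, and the approach is a format-independent variant of the same rename, not a fix inside Keras itself.

```python
import json
from typing import Any


def rename_class(node: Any, old: str, new: str) -> Any:
    """Return a copy of a Keras config tree with class_name `old` renamed to `new`."""
    if isinstance(node, dict):
        renamed = {key: rename_class(value, old, new) for key, value in node.items()}
        if renamed.get("class_name") == old:
            renamed["class_name"] = new
        return renamed
    if isinstance(node, list):
        return [rename_class(item, old, new) for item in node]
    # Scalars (strings, numbers, None) are returned unchanged.
    return node


def patch_config_json(raw: bytes) -> bytes:
    """Rewrite IdentityInitializer -> Identity in the raw bytes of a config.json entry."""
    config = json.loads(raw)
    patched = rename_class(config, "IdentityInitializer", "Identity")
    return json.dumps(patched).encode("utf-8")
```

Inside load_model_with_workaround, the decode/replace/encode chain could then become original_data = patch_config_json(original_data), leaving the rest of the copy loop unchanged.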
