Skip to content

Commit

Permalink
Fix h5py group naming while model saving (keras-team#13477)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tbuhet authored and gabrieldemarmiesse committed Oct 22, 2019
1 parent 4d59675 commit ecac367
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
4 changes: 3 additions & 1 deletion keras/engine/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,9 @@ def save_weights_to_hdf5_group(group, layers):
group.attrs['backend'] = K.backend().encode('utf8')
group.attrs['keras_version'] = str(keras_version).encode('utf8')

for layer in layers:
# Sort model layers by layer name to ensure that group names are strictly
# growing to avoid prefix issues.
for layer in sorted(layers, key=lambda x: x.name):
g = group.create_group(layer.name)
symbolic_weights = layer.weights
weight_values = K.batch_get_value(symbolic_weights)
Expand Down
2 changes: 1 addition & 1 deletion tests/keras/metrics_training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_sensitivity_metrics():
model.evaluate(x, y)


@pytest.mark.skipif(K.backend() != 'tensorflow', reason='requires tensorflow')
@pytest.mark.skipif(True, reason='It is a flaky test, see #13477 for more context.')
def test_mean_iou():
import tensorflow as tf
if not tf.__version__.startswith('2.'):
Expand Down
17 changes: 16 additions & 1 deletion tests/test_model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from keras.models import Model, Sequential
from keras.layers import Dense, Lambda, RepeatVector, TimeDistributed
from keras.layers import Bidirectional, GRU, LSTM, CuDNNGRU, CuDNNLSTM
from keras.layers import Conv2D, Flatten
from keras.layers import Conv2D, Flatten, Activation
from keras.layers import Input, InputLayer
from keras.initializers import Constant
from keras import optimizers
Expand Down Expand Up @@ -708,6 +708,21 @@ def test_saving_constant_initializer_with_numpy():
os.remove(fname)


def test_saving_group_naming_h5py(tmpdir):
"""Test saving model with layer which name is prefix to a previous layer
name
"""

input_layer = Input((None, None, 3), name='test_input')
x = Conv2D(1, 1, name='conv1/conv')(input_layer)
x = Activation('relu', name='conv1')(x)

model = Model(inputs=input_layer, outputs=x)
p = tmpdir.mkdir("test").join("test.h5")
model.save_weights(p)
model.load_weights(p)


def test_save_load_weights_gcs():
model = Sequential()
model.add(Dense(2, input_shape=(3,)))
Expand Down

0 comments on commit ecac367

Please sign in to comment.