Skip to content

Commit

Permalink
Improved error message for missing class_weights. (keras-team#7238)
Browse files Browse the repository at this point in the history
The error is now a ValueError, as per docs, and contains
a suitable message.
  • Loading branch information
jorgecarleitao authored and fchollet committed Jul 5, 2017
1 parent 59cd1c3 commit 985c441
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
13 changes: 12 additions & 1 deletion keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,18 @@ def _standardize_weights(y, sample_weight=None, class_weight=None,
y_classes = np.reshape(y, y.shape[0])
else:
y_classes = y
weights = np.asarray([class_weight[cls] for cls in y_classes])

weights = np.asarray([class_weight[cls] for cls in y_classes
if cls in class_weight])

if len(weights) != len(y_classes):
# subtract the sets to pick all missing classes
existing_classes = set(y_classes)
existing_class_weight = set(class_weight.keys())
raise ValueError('`class_weight` must contain all classes in the data.'
' The classes %s exist in the data but not in '
'`class_weight`.'
% (existing_classes - existing_class_weight))
return weights
else:
if sample_weight_mode is None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_loss_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,20 @@ def test_sequential_temporal_sample_weights():
assert(score < standard_score_sequential)


@keras_test
def test_class_weight_wrong_classes():
model = create_sequential_model()
model.compile(loss=loss, optimizer='rmsprop')

(x_train, y_train), (x_test, y_test), (sample_weight, class_weight, test_ids) = _get_test_data()

del class_weight[1]
try:
model.fit(x_train, y_train, epochs=0, verbose=0, class_weight=class_weight)
assert False
except ValueError:
pass # expected behavior is to raise a ValueError with a suitable message


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 985c441

Please sign in to comment.