Skip to content

Commit 1ea33b9

Browse files
committed
weight_metrics -> weigh_metrics
1 parent 77675d4 commit 1ea33b9

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

keras/engine/training.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ class Model(Container):
564564
"""
565565

566566
def compile(self, optimizer, loss, metrics=None, loss_weights=None,
567-
sample_weight_mode=None, weight_metrics=False, **kwargs):
567+
sample_weight_mode=None, weigh_metrics=False, **kwargs):
568568
"""Configures the model for training.
569569
570570
# Arguments
@@ -597,7 +597,7 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None,
597597
If the model has multiple outputs, you can use a different
598598
`sample_weight_mode` on each output by passing a
599599
dictionary or a list of modes.
600-
weight_metrics: bool whether or not to apply `sample_weight` or
600+
weigh_metrics: bool whether or not to apply `sample_weight` or
601601
`class_weight` to the supplied metrics during training and testing
602602
**kwargs: when using the Theano/CNTK backends, these arguments
603603
are passed into K.function. When using the TensorFlow backend,
@@ -835,7 +835,7 @@ def append_metric(layer_num, metric_name, metric_tensor):
835835
continue
836836
y_true = self.targets[i]
837837
y_pred = self.outputs[i]
838-
weights = sample_weights[i] if weight_metrics else None
838+
weights = sample_weights[i] if weigh_metrics else None
839839
output_metrics = nested_metrics[i]
840840
for metric in output_metrics:
841841
if metric == 'accuracy' or metric == 'acc':

tests/test_loss_weighting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_class_weight_wrong_classes():
165165
@keras_test
166166
def test_sample_weights_with_weighted_metrics():
167167
model = create_sequential_model()
168-
model.compile(loss=loss, optimizer='rmsprop', metrics=[loss], weight_metrics=True)
168+
model.compile(loss=loss, optimizer='rmsprop', metrics=[loss], weigh_metrics=True)
169169

170170
(x_train, y_train), (x_test, y_test), (sample_weight, class_weight, test_ids) = _get_test_data()
171171

0 commit comments

Comments
 (0)