Add weighted_metrics arg to compile #7536


Merged: 6 commits, Aug 12, 2017
108 changes: 41 additions & 67 deletions keras/engine/training.py
@@ -465,46 +465,6 @@ def weighted(y_true, y_pred, weights, mask=None):
return weighted


def _masked_objective(fn):
"""Adds support for masking to an objective function.

It transforms an objective function `fn(y_true, y_pred)`
into a cost-masked objective function
`fn(y_true, y_pred, mask)`.

# Arguments
fn: The objective function to wrap,
with signature `fn(y_true, y_pred)`.

# Returns
A function with signature `fn(y_true, y_pred, mask)`.
"""
def masked(y_true, y_pred, mask=None):
"""Wrapper function.

# Arguments
y_true: `y_true` argument of `fn`.
y_pred: `y_pred` argument of `fn`.
mask: Mask tensor.

# Returns
Scalar tensor.
"""
# score_array has ndim >= 2
score_array = fn(y_true, y_pred)
if mask is not None:
# Cast the mask to floatX to avoid float64 upcasting in theano
mask = K.cast(mask, K.floatx())
# mask should have the same shape as score_array
score_array *= mask
# the loss per batch should be proportional
# to the number of unmasked samples.
score_array /= K.mean(mask)

return K.mean(score_array)
return masked


def _standardize_weights(y, sample_weight=None, class_weight=None,
sample_weight_mode=None):
"""Performs sample weight validation and standardization.
@@ -604,7 +564,7 @@ class Model(Container):
"""

def compile(self, optimizer, loss, metrics=None, loss_weights=None,
sample_weight_mode=None, **kwargs):
sample_weight_mode=None, weighted_metrics=None, **kwargs):
"""Configures the model for training.

# Arguments
@@ -637,6 +597,8 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None,
If the model has multiple outputs, you can use a different
`sample_weight_mode` on each output by passing a
dictionary or a list of modes.
weighted_metrics: list of metrics to be evaluated and weighted
by sample_weight or class_weight during training and testing
**kwargs: when using the Theano/CNTK backends, these arguments
are passed into K.function. When using the TensorFlow backend,
these arguments are passed into `tf.Session.run`.
Expand Down Expand Up @@ -822,6 +784,7 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None,

# Prepare metrics.
self.metrics = metrics
self.weighted_metrics = weighted_metrics
self.metrics_names = ['loss']
self.metrics_tensors = []

Expand Down Expand Up @@ -860,6 +823,7 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None,
# List of same size as output_names.
# contains tuples (metrics for output, names of metrics).
nested_metrics = _collect_metrics(metrics, self.output_names)
nested_weighted_metrics = _collect_metrics(weighted_metrics, self.output_names)

def append_metric(layer_num, metric_name, metric_tensor):
"""Helper function used in loop below."""
@@ -871,36 +835,46 @@ def append_metric(layer_num, metric_name, metric_tensor):
for i in range(len(self.outputs)):
if i in skip_indices:
continue

y_true = self.targets[i]
y_pred = self.outputs[i]
weights = sample_weights[i]
output_metrics = nested_metrics[i]
for metric in output_metrics:
if metric == 'accuracy' or metric == 'acc':
# custom handling of accuracy
# (because of class mode duality)
output_shape = self.internal_output_shapes[i]
acc_fn = None
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
# case: binary accuracy
acc_fn = metrics_module.binary_accuracy
elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
# case: categorical accuracy with sparse targets
acc_fn = metrics_module.sparse_categorical_accuracy
else:
acc_fn = metrics_module.categorical_accuracy
output_weighted_metrics = nested_weighted_metrics[i]

def handle_metrics(metrics, weights=None):
metric_name_prefix = 'weighted_' if weights is not None else ''

for metric in metrics:
if metric == 'accuracy' or metric == 'acc':
# custom handling of accuracy
# (because of class mode duality)
output_shape = self.internal_output_shapes[i]
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
# case: binary accuracy
acc_fn = metrics_module.binary_accuracy
elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
# case: categorical accuracy with sparse targets
acc_fn = metrics_module.sparse_categorical_accuracy
else:
acc_fn = metrics_module.categorical_accuracy

masked_fn = _masked_objective(acc_fn)
append_metric(i, 'acc', masked_fn(y_true, y_pred, mask=masks[i]))
else:
metric_fn = metrics_module.get(metric)
masked_metric_fn = _masked_objective(metric_fn)
metric_result = masked_metric_fn(y_true, y_pred, mask=masks[i])
metric_result = {
metric_fn.__name__: metric_result
}
for name, tensor in six.iteritems(metric_result):
append_metric(i, name, tensor)
acc_fn = _weighted_masked_objective(acc_fn)
metric_name = metric_name_prefix + 'acc'
append_metric(i, metric_name, acc_fn(y_true, y_pred, weights=weights, mask=masks[i]))
else:
metric_fn = metrics_module.get(metric)
weighted_metric_fn = _weighted_masked_objective(metric_fn)
metric_result = weighted_metric_fn(y_true, y_pred, weights=weights, mask=masks[i])
metric_result = {
metric_fn.__name__: metric_result
}
for name, tensor in six.iteritems(metric_result):
append_metric(i, metric_name_prefix + name, tensor)

handle_metrics(output_metrics)
handle_metrics(output_weighted_metrics, weights=weights)

# Prepare gradient updates and state updates.
self.total_loss = total_loss
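For reference, the _weighted_masked_objective helper called throughout the new handle_metrics closure replaces the removed _masked_objective but its body is not shown in this diff. The sketch below (not part of the PR) approximates what such a wrapper does, inferred from the removed function above and from the weighting behavior the tests check; the name _weighted_masked_objective_sketch and the exact renormalization step are assumptions, not the literal implementation.

from keras import backend as K

def _weighted_masked_objective_sketch(fn):
    """Hypothetical sketch: turns `fn(y_true, y_pred)` into a function
    `fn(y_true, y_pred, weights, mask)` that applies masking and
    per-sample weighting before averaging."""
    def weighted(y_true, y_pred, weights, mask=None):
        # score_array has ndim >= 2: one score per sample (and timestep)
        score_array = fn(y_true, y_pred)
        if mask is not None:
            # Same masking logic as the removed _masked_objective above.
            mask = K.cast(mask, K.floatx())
            score_array *= mask
            score_array /= K.mean(mask)
        if weights is not None:
            # Reduce scores to the same ndim as the weight array, scale by
            # the weights, then renormalize by the fraction of non-zero
            # weights so that all-ones weights leave the metric unchanged.
            ndim = K.ndim(score_array)
            weight_ndim = K.ndim(weights)
            score_array = K.mean(score_array,
                                 axis=list(range(weight_ndim, ndim)))
            score_array *= weights
            score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
        return K.mean(score_array)
    return weighted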
5 changes: 5 additions & 0 deletions keras/models.py
@@ -745,6 +745,7 @@ def save_weights(self, filepath, overwrite=True):
def compile(self, optimizer, loss,
metrics=None,
sample_weight_mode=None,
weighted_metrics=None,
**kwargs):
"""Configures the learning process.

@@ -760,6 +761,8 @@ def compile(self, optimizer, loss,
sample_weight_mode: if you need to do timestep-wise
sample weighting (2D weights), set this to "temporal".
"None" defaults to sample-wise weights (1D).
weighted_metrics: list of metrics to be evaluated and weighted
by sample_weight or class_weight during training and testing
**kwargs: for Theano/CNTK backends, these are passed into
K.function. When using the TensorFlow backend, these are
passed into `tf.Session.run`.
@@ -780,12 +783,14 @@ def compile(self, optimizer, loss,
self.model.compile(optimizer, loss,
metrics=metrics,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
**kwargs)
self.optimizer = self.model.optimizer
self.loss = self.model.loss
self.total_loss = self.model.total_loss
self.loss_weights = self.model.loss_weights
self.metrics = self.model.metrics
self.weighted_metrics = self.model.weighted_metrics
self.metrics_tensors = self.model.metrics_tensors
self.metrics_names = self.model.metrics_names
self.sample_weight_mode = self.model.sample_weight_mode
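As a usage illustration (not part of the diff): after this change, weighted_metrics is passed alongside metrics at compile time, and the weighted variants appear in the training history under a 'weighted_' prefix, which is the naming convention the tests below rely on. The toy data and layer sizes in this sketch are arbitrary.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

model = Sequential([Dense(1, input_shape=(4,))])
# 'mse' is tracked twice: once as a plain metric, once weighted by sample_weight.
model.compile(optimizer='rmsprop', loss='mse',
              metrics=['mse'], weighted_metrics=['mse'])

x = np.random.random((8, 4))
y = np.random.random((8, 1))
w = np.array([1, 1, 1, 1, 5, 5, 5, 5], dtype='float32')

history = model.fit(x, y, sample_weight=w, epochs=1, verbose=0)
# history.history contains 'loss', 'mean_squared_error'
# and 'weighted_mean_squared_error'.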
129 changes: 127 additions & 2 deletions tests/test_loss_weighting.py
@@ -3,11 +3,13 @@
import pytest
import numpy as np

from keras import backend as K
from keras.utils.test_utils import get_test_data
from keras.models import Sequential
from keras.layers import Dense, Activation, GRU, TimeDistributed
from keras.models import Sequential, Model
from keras.layers import Dense, Activation, GRU, TimeDistributed, Input
from keras.utils import np_utils
from keras.utils.test_utils import keras_test
from numpy.testing import assert_almost_equal, assert_array_almost_equal

num_classes = 10
batch_size = 128
@@ -19,9 +21,16 @@
timesteps = 3
input_dim = 10
loss = 'mse'
loss_full_name = 'mean_squared_error'
standard_weight = 1
standard_score_sequential = 0.5

decimal_precision = {
'cntk': 2,
[Inline review comments on 'cntk': 2]
Collaborator: Why?
Contributor Author: CNTK is only accurate to two decimal places. I'm not sure why.

'theano': 6,
'tensorflow': 6
}


def _get_test_data():
np.random.seed(1337)
@@ -148,6 +157,122 @@ def test_sequential_temporal_sample_weights():
assert(score < standard_score_sequential)


@keras_test
def test_weighted_metrics_with_sample_weight():
decimal = decimal_precision[K.backend()]

model = create_sequential_model()
model.compile(loss=loss, optimizer='rmsprop', metrics=[loss], weighted_metrics=[loss])

(x_train, y_train), (x_test, y_test), (sample_weight, class_weight, test_ids) = _get_test_data()

history = model.fit(x_train, y_train, batch_size=batch_size,
epochs=epochs // 3, verbose=0,
sample_weight=sample_weight)

h = history.history
assert_array_almost_equal(h['loss'], h['weighted_' + loss_full_name], decimal=decimal)

history = model.fit(x_train, y_train, batch_size=batch_size,
epochs=epochs // 3, verbose=0,
sample_weight=sample_weight,
validation_split=0.1)

h = history.history
assert_almost_equal(h['val_loss'], h['val_weighted_' + loss_full_name], decimal=decimal)

model.train_on_batch(x_train[:32], y_train[:32],
sample_weight=sample_weight[:32])
model.test_on_batch(x_train[:32], y_train[:32],
sample_weight=sample_weight[:32])

test_sample_weight = np.ones((y_test.shape[0])) * standard_weight
test_sample_weight[test_ids] = high_weight

scores = model.evaluate(x_test, y_test, verbose=0, sample_weight=test_sample_weight)
loss_score, metric_score, weighted_metric_score = scores

assert loss_score < standard_score_sequential
assert loss_score != metric_score
assert_almost_equal(loss_score, weighted_metric_score, decimal=decimal)


@keras_test
def test_weighted_metrics_with_no_sample_weight():
decimal = decimal_precision[K.backend()]

model = create_sequential_model()
model.compile(loss=loss, optimizer='rmsprop', metrics=[loss], weighted_metrics=[loss])

(x_train, y_train), (x_test, y_test), _ = _get_test_data()

history = model.fit(x_train, y_train, batch_size=batch_size,
epochs=epochs // 3, verbose=0)

h = history.history
assert_array_almost_equal(h['loss'], h[loss_full_name], decimal=decimal)
assert_array_almost_equal(h['loss'], h['weighted_' + loss_full_name], decimal=decimal)

history = model.fit(x_train, y_train, batch_size=batch_size,
epochs=epochs // 3, verbose=0, validation_split=0.1)

h = history.history
assert_array_almost_equal(h['val_loss'], h['val_' + loss_full_name], decimal=decimal)
assert_array_almost_equal(h['val_loss'], h['val_weighted_' + loss_full_name], decimal=decimal)

model.train_on_batch(x_train[:32], y_train[:32])
model.test_on_batch(x_train[:32], y_train[:32])

scores = model.evaluate(x_test, y_test, verbose=0)
loss_score, metric_score, weighted_metric_score = scores

assert_almost_equal(loss_score, metric_score, decimal=decimal)
assert_almost_equal(loss_score, weighted_metric_score, decimal=decimal)


@keras_test
def test_weighted_metrics_with_weighted_accuracy_metric():
model = create_sequential_model()
model.compile(loss=loss, optimizer='rmsprop', metrics=['acc'], weighted_metrics=['acc'])

(x_train, y_train), _, (sample_weight, _, _) = _get_test_data()

history = model.fit(x_train, y_train, batch_size=batch_size,
epochs=epochs // 3, verbose=0,
sample_weight=sample_weight)

assert history.history['acc'] != history.history['weighted_acc']


@keras_test
def test_weighted_metrics_with_multiple_outputs():
decimal = decimal_precision[K.backend()]

inputs = Input(shape=(5,))
x = Dense(5)(inputs)
output1 = Dense(1, name='output1')(x)
output2 = Dense(1, name='output2')(x)

model = Model(inputs=inputs, outputs=[output1, output2])

metrics = {'output1': [loss], 'output2': [loss]}
weighted_metrics = {'output2': [loss]}
loss_map = {'output1': loss, 'output2': loss}

model.compile(loss=loss_map, optimizer='sgd', metrics=metrics, weighted_metrics=weighted_metrics)

x = np.array([[1, 1, 1, 1, 1]])
y = {'output1': np.array([0]), 'output2': np.array([1])}
weight = 5

history = model.fit(x, y, sample_weight={'output2': np.array([weight])})
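
# With only one sample and a weight of 5, the weighted metric should be
# weight (5) times the unweighted one, which the assertion below verifies.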

unweighted_metric = history.history['output2_' + loss_full_name][0]
weighted_metric = history.history['output2_weighted_' + loss_full_name][0]

assert_almost_equal(unweighted_metric * weight, weighted_metric, decimal=decimal)


@keras_test
def test_class_weight_wrong_classes():
model = create_sequential_model()