
Merge upstream #62

Merged
237 commits, Apr 11, 2020
Changes from 1 commit
Commits
237 commits
7c84229
First pass.
fchollet Mar 5, 2019
17eab76
2nd pass.
fchollet Mar 5, 2019
bbba9ce
3rd pass
fchollet Mar 5, 2019
4d857be
4th pass
fchollet Mar 6, 2019
8cc7a97
Revert backend rnn.
fchollet Mar 6, 2019
c4d29cd
5th pass
fchollet Mar 6, 2019
0ca37c9
Quick fixes
fchollet Mar 7, 2019
c30223e
Simplification
fchollet Mar 8, 2019
f9e4f20
Update Travis config to test on TF 2
fchollet Mar 8, 2019
0a7be87
Fix some syntax error
fchollet Mar 8, 2019
255caad
Fix travis issues again
fchollet Mar 8, 2019
4b9d169
Fixes
fchollet Mar 8, 2019
ed387f1
Unit test fixes
fchollet Mar 8, 2019
820251a
Small fix.
fchollet Mar 12, 2019
c8b0e33
Tf 2: fix optimizer weights naming collision issue (#12466)
farizrahman4u Mar 12, 2019
7e5d34f
Merge branch 'tf-2' of github.com:keras-team/keras into tf-2
fchollet Mar 12, 2019
2fbcf34
Fix docstring of util function.
fchollet Mar 12, 2019
a41ccec
Fix docstring style
fchollet Mar 12, 2019
408a344
TF-2: Remove get_session() call in multi_gpu_utils.py (#12465)
farizrahman4u Mar 12, 2019
286a208
Merge branch 'tf-2' of github.com:keras-team/keras into tf-2
fchollet Mar 12, 2019
c6428ee
Fix in_top_k
fchollet Mar 12, 2019
7c5057b
Simplify bidirectional test
fchollet Mar 12, 2019
995f1e7
Move TensorBoard callback to v2 -- still need to fix some tests.
fchollet Mar 13, 2019
a7eff2f
Merge branch 'master' into tf-2
fchollet Mar 14, 2019
8240ef7
Small fixes.
fchollet Mar 14, 2019
5b653e2
Fix v1 tests.
fchollet Mar 14, 2019
ea1acc1
Merge branch 'master' into tf-2
fchollet Mar 15, 2019
331d5b0
Fix PEP8.
fchollet Mar 15, 2019
8e23a3e
Fix initializers and update ops.
fchollet Mar 22, 2019
b1fbc24
Merge branch 'master' into tf-2
fchollet Mar 22, 2019
05637cf
Disable test for TF print
fchollet Mar 22, 2019
d3ab82c
Fix gradient test failure.
fchollet Mar 22, 2019
e8ec47a
Fix test_model_with_external_loss
fchollet Mar 22, 2019
0cccb83
Small backend simplification
fchollet Mar 22, 2019
f28ef61
fix merge conflict
fchollet Mar 25, 2019
d5798e7
Merge master.
fchollet Apr 6, 2019
ed144aa
Fix convrnn tests.
fchollet Apr 6, 2019
e361a26
Fix identity init
fchollet Apr 6, 2019
8c02f4e
Remove irrelevant tests.
fchollet Apr 6, 2019
7d57c58
Fix conv2d_transpose
fchollet Apr 7, 2019
10970b9
Fix PEP8
fchollet Apr 7, 2019
32aae34
Merge branch 'master' into tf-2
fchollet Apr 12, 2019
d48c261
Merge branch 'master' into tf-2
fchollet Apr 17, 2019
747ef0f
Disable multiprocessing tests.
fchollet Apr 17, 2019
9c86747
Merge branch 'master' into tf-2
fchollet May 21, 2019
3b2e5e5
Fix tests.
fchollet May 21, 2019
510f9f0
Fix conv_rnn bug with cntk/theano
fchollet May 22, 2019
5430a7b
Fix TF1 test
fchollet May 23, 2019
0792332
Adding Loss, LossFunctionWrapper, MeanSquaredError classes. (#12859)
pavithrasv May 24, 2019
f870823
Merge branch 'master' into tf-2
fchollet May 24, 2019
3d48e27
Adding MeanAbsoluteError, MeanAbsolutePercentageError, MeanSquaredLog…
pavithrasv May 31, 2019
9551d2d
Update image preprocessing.
fchollet May 31, 2019
e3d20c6
Update applications.
fchollet May 31, 2019
99ebe77
Remove ResNeXt networks (bug) and add tests.
fchollet May 31, 2019
ab3ef6f
Adding CategoricalCrossentropy and SparseCategoricalCrossentropy Loss…
pavithrasv Jun 1, 2019
aa23003
Fix docstring
fchollet Jun 4, 2019
910e124
Adding support for `Loss` instances in model compile. (#12915)
pavithrasv Jun 6, 2019
cd3f27b
Ingore `xla_gpu` which is specially reserved by TF (#12928)
ghostplant Jun 7, 2019
850ffbf
Merge branch 'master' of github.com:keras-team/keras
fchollet Jun 7, 2019
f312e5a
Remove ignored applications files.
fchollet Jun 7, 2019
c658993
correct DepthwiseConv2D docstring (#12949)
adrianstaniec Jun 17, 2019
b810de6
Adding weighted_metrics to model loading/saving (#12984)
tallakahath Jun 23, 2019
58dd522
Merge branch 'master' of github.com:keras-team/keras
fchollet Jun 24, 2019
2d2fb47
Remove deprecated applications adapter code
fchollet Jun 24, 2019
dccb24a
Merge branch 'master' into tf-2
fchollet Jun 24, 2019
c714efa
Fix a number of tests.
fchollet Jun 24, 2019
1575b8c
Remove outdated test
fchollet Jun 24, 2019
5f37eaa
Fix PEP8
fchollet Jun 25, 2019
5949aee
Change defaults of GRU and LSTM layers.
fchollet Jun 25, 2019
36098bf
Rename lr to learning_rate in optimizers
fchollet Jun 25, 2019
ba78b1b
Add missing loss classes
fchollet Jun 26, 2019
979b636
Fix tests.
fchollet Jun 26, 2019
613aeff
fine tune sparse_categorical_crossentropy (#13010)
bojone Jun 27, 2019
442423e
Removing @symbolic from few tf backend ops.
pavithrasv Jul 1, 2019
aa04a83
Creating helper function for broadcast weights.
pavithrasv Jul 1, 2019
d01f8cb
Adding Metric class.
pavithrasv Jul 1, 2019
fd35671
Adding Metric class.
pavithrasv Jul 1, 2019
e73a8ce
Adding sample weight unit test.
pavithrasv Jul 1, 2019
7350d1e
Adding Mean metric class and unit tests.
pavithrasv Jul 1, 2019
e2c0da4
Addressed comments.
pavithrasv Jul 2, 2019
c93bd65
Adding control dependency op.
pavithrasv Jul 2, 2019
62c6b77
Adding no-op control dependencies to Theano and CNTK backends.
pavithrasv Jul 2, 2019
3e77ab6
Attempt to fix tests.
fchollet Jul 3, 2019
7de4a8b
Merge branch 'master' of github.com:keras-team/keras
fchollet Jul 3, 2019
17f6814
Sync up with keras-applications (#13044)
taehoonlee Jul 3, 2019
9e99fcf
Fixing tests for TF1.
pavithrasv Jul 3, 2019
0cdff54
Fix doc string.
pavithrasv Jul 3, 2019
72b55d2
Remove outdated integration test.
fchollet Jul 3, 2019
d952f0c
Merge branch 'master' of github.com:keras-team/keras
fchollet Jul 3, 2019
f06524c
load_weights() now properly closes file (#13048)
bthorsted Jul 3, 2019
0fc33fe
Fixed typo (#13060)
HarikrishnanBalagopal Jul 3, 2019
b365dcf
Fix backend issues.
fchollet Jul 3, 2019
85e530f
Adding MeanMetricWrapper class.
pavithrasv Jul 6, 2019
3e8a7ca
Adding MeanSquaredError metric.
pavithrasv Jul 6, 2019
87b2b00
Framework changes for metrics part 1
pavithrasv Jul 6, 2019
b78e3ff
Update conv_filter_visualization.py (#13032)
abhinavsagar Jul 7, 2019
b938b91
Metrics framework changes part 2
pavithrasv Jul 8, 2019
72a2aa5
Adding metrics correctness test.
pavithrasv Jul 8, 2019
3bda552
Update pretrained_word_embeddings.py (#13073)
abhipn Jul 9, 2019
ebe3f30
Fix integration tests
fchollet Jul 9, 2019
ed07472
Allow unrolled RNNs with input_length=1 (#13078)
farizrahman4u Jul 10, 2019
efe72ef
For better performance (#13144)
Neutron3529 Jul 23, 2019
c10d249
RNN initial state: bug fix + suppress false warning (#13138)
farizrahman4u Jul 25, 2019
4385de6
Remove references to ResNeXt from docs.
fchollet Aug 21, 2019
aa28910
Prepare 2.2.5 release.
fchollet Aug 22, 2019
c074416
Fix sklearn wrapper unit test in Python 3?
fchollet Aug 22, 2019
fb7f49e
Fix sklearn regressor test?
fchollet Aug 22, 2019
790c74c
Some Theano fixes.
fchollet Aug 23, 2019
8feac19
Make metrics compatible with Theano
fchollet Aug 24, 2019
caceebc
Theano fixes
fchollet Aug 25, 2019
2ccb69d
Fix
fchollet Aug 25, 2019
21efc67
Fixes
fchollet Aug 25, 2019
ee3997f
Recompute steps_per_epoch after each epoch in traingin_generator (#13…
vikua Aug 25, 2019
a39f10a
pep8 config in setup.cfg (#13196)
PhilipMay Aug 25, 2019
387aea3
Theano fixes
fchollet Aug 25, 2019
5446255
Update optimizers for TF2. (#13246)
tanzhenyu Aug 26, 2019
f6bdacd
Fix results tracking for metrics in multi-output case
fchollet Aug 26, 2019
2bb96b6
sync changes to _TfDeviceCaptureOp (#13255)
Aug 28, 2019
61052bc
Documentation for `array_to_img`, `img_to_array` and `save_img` under…
mathemage Aug 28, 2019
a47f5e2
Add metric API changes (#13256)
pavithrasv Aug 28, 2019
d9fae78
Fix metrics support in Theano
fchollet Aug 28, 2019
3625bf4
Fix PEP8
fchollet Aug 28, 2019
555ca94
Introduces fixes for tensor equality in TF 2.0
fchollet Aug 28, 2019
200358f
Minor fixes
fchollet Aug 28, 2019
18c11a3
Improve exception testing in test_training
fchollet Aug 28, 2019
a996d3d
Improve test syntax
fchollet Aug 28, 2019
cca4925
Merge branch 'master' into tf-2
fchollet Aug 28, 2019
d36c7b2
Remove outdated tests
fchollet Aug 28, 2019
255a1ac
Only run label smoothing logic when necessary
fchollet Aug 28, 2019
c2b4d8f
Fix PEP8
fchollet Aug 28, 2019
17082f6
Update CI to run on TF 1.14 for TF1
fchollet Aug 28, 2019
10c8e50
Disable a backend test for CNTK
fchollet Aug 28, 2019
d4fe07c
Fix docs test
fchollet Aug 28, 2019
0407495
CNTK fixes
fchollet Aug 28, 2019
91658c5
Disable test that hangs Travis
fchollet Aug 29, 2019
67fb3de
Disabled flaky cntk test
fchollet Aug 29, 2019
bda6bb1
Reduce test flakiness
fchollet Aug 29, 2019
b10f5ca
Disable CNTK SGD test
fchollet Aug 29, 2019
4981b70
Disable test causing Travis to hang
fchollet Aug 29, 2019
daebe7c
Disable flaky CNTK test
fchollet Aug 29, 2019
a69eaa8
Disable test that hangs Travis
fchollet Aug 29, 2019
519fac7
Disable a couple more multiprocessing tests
fchollet Aug 29, 2019
3b853b9
Add ability for Layer to track sublayer attributes
fchollet Aug 29, 2019
479fc3a
Add support for layer attribute tracking (loss, updates, metrics) in …
fchollet Aug 30, 2019
655cfd3
Fix theano backend
fchollet Aug 30, 2019
a82a7e9
Fix PEP8
fchollet Aug 30, 2019
969db5a
Adding accuracy metric classes. (#13265)
pavithrasv Aug 30, 2019
9cfc0d1
Add metrics Hinge, SquaredHinge, CategoricalHinge
fchollet Aug 30, 2019
0fe44e3
Merge
fchollet Aug 30, 2019
73106d5
Add label conversion to hinge losses
fchollet Aug 30, 2019
680be2e
Adding LogCosh, Poisson, KLDivergence, crossentropy metrics. (#13271)
pavithrasv Aug 30, 2019
38e3831
Add metrics CosineSimilarity, MeanAbsoluteError, MeanAbsolutePercenta…
fchollet Aug 30, 2019
d830024
Reverse sign of cosine_similarity metric
fchollet Aug 30, 2019
088bda5
Adding TruePositives, TrueNegatives, FalsePositives, FalseNegatives m…
pavithrasv Sep 4, 2019
63c0369
Adding AUC, SensitivityAtSpecificity metrics. (#13289)
pavithrasv Sep 5, 2019
5dc27d0
Add SpecificityAtSensitivity metric. (#13294)
pavithrasv Sep 6, 2019
ab124c3
Adding Precision, Recall, Mean IoU part 1.
pavithrasv Sep 6, 2019
dcb7a16
Adding Precision, Recall, Mean IoU part 1.
pavithrasv Sep 6, 2019
de27619
Add MeanIoU metric.
pavithrasv Sep 7, 2019
4de37bc
Add MeanIoU metric.
pavithrasv Sep 7, 2019
8de897a
Fix metrics reporting / accumulation with fit_generator and evaluate…
fchollet Sep 8, 2019
dd69e39
Merge branch 'tf-2' of github.com:keras-team/keras into tf-2
fchollet Sep 8, 2019
2bc43bf
Remove deprecated example script
fchollet Sep 8, 2019
d8446ef
Remove deprecated example, fix conv filter example
fchollet Sep 8, 2019
e6d685b
Update examples
fchollet Sep 8, 2019
6bb3fb7
Fix some bugs
fchollet Sep 8, 2019
763f69f
Update examples
fchollet Sep 9, 2019
69fcefd
Addressed PR comments.
pavithrasv Sep 9, 2019
bb9293c
Merge branch 'tf-2' of https://github.com/pavithrasv/keras into pavit…
fchollet Sep 9, 2019
388cbf1
Fix Theano tests.
fchollet Sep 9, 2019
922dd77
Merge branch 'pavithrasv-tf-2' into tf-2
fchollet Sep 9, 2019
a1be7c3
Disable top_k metrics for TF1
fchollet Sep 9, 2019
fe4110f
Reenable Precision and Recall with TF1
fchollet Sep 9, 2019
b86a986
Fix Py2 tests
fchollet Sep 9, 2019
3e70caf
Skip metric tests for CNTK
fchollet Sep 9, 2019
da289bb
Fix py2 test
fchollet Sep 9, 2019
bbed5cf
Fix py2 test
fchollet Sep 9, 2019
a8478c3
Update coverage threshold
fchollet Sep 9, 2019
1cf5218
Merge branch 'master' of github.com:keras-team/keras
fchollet Sep 9, 2019
88ca804
Add back CPU to multi_gpu_utils available devices
fchollet Sep 10, 2019
3e8e273
Using K.is_tensor and K.is_variable (#13307)
shoeffner Sep 11, 2019
8315a0b
Update lstm_seq2seq.py(from 22% to 87% acc) (#13269)
tykimos Sep 11, 2019
6645537
Update babi_rnn.py (#13263)
BluFalcon Sep 11, 2019
280e5b7
typo fixed (#13230)
ArnoutDevos Sep 11, 2019
ccecd39
Correct the DepthwiseConv2d docstrings - output shape (#13225)
keunwoochoi Sep 11, 2019
1eac861
fix in "Layer.compute_output_shape" description (#13210)
Inkln Sep 11, 2019
cb96315
Added batch_normalization in the numpy backend. (#11556)
gabrieldemarmiesse Sep 11, 2019
93b0f1c
Complete the docs by adding data to multi-input/output example (#12775)
bharatr21 Sep 11, 2019
033983d
Fix Travis SSL issue.
fchollet Sep 11, 2019
7183813
Merge branch 'master' of github.com:keras-team/keras
fchollet Sep 11, 2019
d3512f7
#13239 Improved documentation for EarlyStopping/ReduceLROnPlateau, ta…
hendriks73 Sep 12, 2019
7869134
Added messages about the future of multi-backend Keras. (#13315)
gabrieldemarmiesse Sep 15, 2019
cf9595a
Fix sequence timeout deadlock (#13322)
andreyz4k Sep 15, 2019
9080613
Fix deprecation warnings related to TF v1
fchollet Sep 16, 2019
a0335a3
Update README
fchollet Sep 17, 2019
8a8ef43
Add a link to the metrics document (#13334)
Sep 18, 2019
977d55c
Fix thread safety issue
fchollet Sep 19, 2019
95fab0e
Merge branch 'master' of github.com:keras-team/keras
fchollet Sep 19, 2019
4f94a15
Correct spelling mistake (#13339)
shivdhar Sep 20, 2019
04cbccc
fix #13341 math_ops to K (#13342)
djstrong Sep 20, 2019
9ad5a18
Fix encoding error (#13355)
fuzzythecat Sep 25, 2019
f2bbf98
Fix issue where the disable_tracking decorator obfuscates layer const…
fchollet Sep 25, 2019
5be4ed3
Fix yaml version compat issue
fchollet Sep 25, 2019
985521e
Update local.py docstrings (#13373)
Naruu Sep 29, 2019
f0464c9
Allowed to return the image as a Jupyter Image only if the extension …
ftesser Oct 1, 2019
16bd239
Fix file leak in CSVLogger (#13378)
GregoryMorse Oct 4, 2019
481e99d
fix: `recurrent_activation` parameter's docstring (#13401)
ndrwnaguib Oct 6, 2019
8904340
typo_fix (#13395)
haifeng-jin Oct 6, 2019
f295e8e
Prepare 2.3.1 release
fchollet Oct 7, 2019
c8f66d1
Added the default activation of convolutional LSTM in the docs. (#13409)
MichelleVivita Oct 8, 2019
b75b2f7
Small refactors on the keras.utils module (#13388)
eltonvs Oct 8, 2019
97e3916
Bumped tf2 version to 2.0.0 (#13412)
gabrieldemarmiesse Oct 9, 2019
2b1f8ed
Change `batch_size` descriptions to proper ones (#13422)
EthanJYK Oct 11, 2019
f242c64
Update autogen.py (#13426)
Naruu Oct 13, 2019
e25f68d
Update io_utils.py (#13429)
Denny-Hwang Oct 18, 2019
afff7b4
Update pooling.py (#13467)
Naruu Oct 21, 2019
4d59675
Update core.py (#13472)
Naruu Oct 21, 2019
ecac367
Fix h5py group naming while model saving (#13477)
Tbuhet Oct 22, 2019
e8946d5
Update np_utils.py (#13481)
Denny-Hwang Oct 24, 2019
7a39b6c
Fix too many values to unpack error (#13511)
xemcerk Nov 6, 2019
94a69c2
Merge remote-tracking branch 'upstream/master'
Feb 20, 2020
de4926f
Update tf backend
Feb 20, 2020
003a294
Update tf backend
Feb 20, 2020
ad9d4de
Update tf backend
Feb 20, 2020
4338492
fix
Feb 20, 2020
83501a0
fix
Feb 20, 2020
252d094
fix
Feb 20, 2020
9ad8e80
fix
Feb 24, 2020
24e3fcb
fix
Feb 24, 2020
9758174
Bump TF and Python versions
Feb 25, 2020
f4c0468
Update dist
Feb 25, 2020
170494a
Update tf_backend
Feb 25, 2020
b97d26d
pep8
Feb 25, 2020
799bf4c
fix
Feb 25, 2020
258fea5
Fix cudnnRNN--RNN weight loading
Apr 11, 2020
22bf5fb
Update config
Apr 11, 2020
45e0ac5
Update .travis
Apr 11, 2020
Adding Loss, LossFunctionWrapper, MeanSquaredError classes. (keras-team#12859)

* Adding Loss, LossFunctionWrapper, MeanSquaredError classes.

* Fixing formatting issues.

* Adding arguments list to MeanSquaredError.

* Fix abstract method
pavithrasv authored and fchollet committed May 24, 2019
commit 0792332f77e6a34093cab9c2bc6f638d1676bccd
1 change: 1 addition & 0 deletions keras/backend/__init__.py
@@ -149,6 +149,7 @@
from .load_backend import name_scope
from .load_backend import symbolic
from .load_backend import eager
from .load_backend import size

if backend() == 'theano':
from .load_backend import pattern_broadcast
4 changes: 4 additions & 0 deletions keras/backend/cntk_backend.py
@@ -559,6 +559,10 @@ def cast(x, dtype):
return x


def size(x, name=None):
return sum(ones_like(x, name=name))


def dot(x, y):
if len(x.shape) > 2 or len(y.shape) > 2:
y_shape = int_shape(y)
23 changes: 23 additions & 0 deletions keras/backend/tensorflow_backend.py
@@ -831,6 +831,29 @@ def ndim(x):
return x.shape.rank


def size(x, name=None):
"""Returns the size of a tensor.

# Arguments
x: Tensor or variable.
name: A name for the operation (optional).

# Returns
Size of the tensor.

# Examples
```python
>>> from keras import backend as K
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.size(kvar)
<tf.Tensor: id=9, shape=(), dtype=int32, numpy=4>
```

"""
return tf.size(x, name=name)


def dtype(x):
"""Returns the dtype of a Keras tensor or variable, as a string.

12 changes: 12 additions & 0 deletions keras/backend/theano_backend.py
@@ -374,6 +374,18 @@ def count_params(x):

def cast(x, dtype):
return T.cast(x, dtype)


def size(x, name=None):
"""Returns the size of a tensor.
# Arguments
x: The input tensor.
name: A name for the operation (optional).
# Returns
Size of the tensor.
```
"""
return sum(ones_like(x, name=name))


# UPDATES OPS
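All three backend implementations of `size` above compute the same quantity: the total element count of a tensor. A framework-free NumPy sketch of that contract (the helper here is illustrative only, not part of the diff):

```python
import numpy as np

def size(x, name=None):
    # Total number of elements, matching tf.size / sum(ones_like(x)).
    # The name argument is kept only for signature parity with the backends.
    return int(np.asarray(x).size)

val = np.array([[1, 2], [3, 4]])
print(size(val))  # 4
```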
162 changes: 162 additions & 0 deletions keras/losses.py
@@ -4,12 +4,174 @@
from __future__ import division
from __future__ import print_function

import abc
import six

from . import backend as K
from .utils import losses_utils
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object


@six.add_metaclass(abc.ABCMeta)
class Loss(object):
"""Loss base class.

To be implemented by subclasses:
* `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.

Example subclass implementation:
```python
class MeanSquaredError(Loss):
def call(self, y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
return K.mean(math_ops.square(y_pred - y_true), axis=-1)
```

# Arguments
reduction: (Optional) Type of loss reduction to apply to loss.
Default value is `SUM_OVER_BATCH_SIZE`.
name: Optional name for the op.
"""

def __init__(self,
reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
name=None):
self.reduction = reduction
self.name = name

def __call__(self, y_true, y_pred, sample_weight=None):
"""Invokes the `Loss` instance.

# Arguments
y_true: Ground truth values.
y_pred: The predicted values.
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
coefficient for the loss. If a scalar is provided, then the loss is
simply scaled by the given value. If `sample_weight` is a tensor of size
`[batch_size]`, then the total loss for each sample of the batch is
rescaled by the corresponding element in the `sample_weight` vector. If
the shape of `sample_weight` matches the shape of `y_pred`, then the
loss of each measurable element of `y_pred` is scaled by the
corresponding value of `sample_weight`.

# Returns
Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
shape as `y_true`; otherwise, it is scalar.

# Raises
ValueError: If the shape of `sample_weight` is invalid.
"""
# If we are wrapping a lambda function, strip '<>' from the name, as it is
# not accepted in a scope name.
scope_name = 'lambda' if self.name == '<lambda>' else self.name
with K.name_scope(scope_name):
losses = self.call(y_true, y_pred)
return losses_utils.compute_weighted_loss(
losses, sample_weight, reduction=self.reduction)

@classmethod
def from_config(cls, config):
"""Instantiates a `Loss` from its config (output of `get_config()`).

# Arguments
config: Output of `get_config()`.

# Returns
A `Loss` instance.
"""
return cls(**config)

def get_config(self):
return {'reduction': self.reduction, 'name': self.name}

@abc.abstractmethod
def call(self, y_true, y_pred):
"""Invokes the `Loss` instance.

# Arguments
y_true: Ground truth values, with the same shape as `y_pred`.
y_pred: The predicted values.
"""
raise NotImplementedError('Must be implemented in subclasses.')
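The class splits the work in two: `__call__` applies weighting and reduction, while subclasses implement the per-sample math in `call`. A framework-free NumPy sketch of that split (the names `NumpyLoss` and `NumpyMSE` are illustrative, not part of this diff):

```python
import numpy as np

class NumpyLoss:
    """Mimics Loss: __call__ weights and reduces what call() computes."""
    def __call__(self, y_true, y_pred, sample_weight=None):
        losses = self.call(np.asarray(y_true), np.asarray(y_pred))
        if sample_weight is not None:
            losses = losses * np.asarray(sample_weight)
        # SUM_OVER_BATCH_SIZE reduction: mean over all loss elements.
        return float(np.mean(losses))

    def call(self, y_true, y_pred):
        raise NotImplementedError('Must be implemented in subclasses.')

class NumpyMSE(NumpyLoss):
    def call(self, y_true, y_pred):
        return np.mean(np.square(y_pred - y_true), axis=-1)

mse = NumpyMSE()
print(mse([0., 0., 1., 1.], [1., 1., 1., 0.]))                     # 0.75
print(mse([0., 0., 1., 1.], [1., 1., 1., 0.], sample_weight=2.0))  # 1.5
```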


class LossFunctionWrapper(Loss):
"""Wraps a loss function in the `Loss` class.

# Arguments
fn: The loss function to wrap, with signature `fn(y_true, y_pred,
**kwargs)`.
reduction: (Optional) Type of loss reduction to apply to loss.
Default value is `SUM_OVER_BATCH_SIZE`.
name: (Optional) name for the loss.
**kwargs: The keyword arguments that are passed on to `fn`.
"""

def __init__(self,
fn,
reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
name=None,
**kwargs):
super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
self.fn = fn
self._fn_kwargs = kwargs

def call(self, y_true, y_pred):
"""Invokes the `LossFunctionWrapper` instance.

# Arguments
y_true: Ground truth values.
y_pred: The predicted values.

# Returns
Loss values per sample.
"""
return self.fn(y_true, y_pred, **self._fn_kwargs)

def get_config(self):
config = {}
for k, v in six.iteritems(self._fn_kwargs):
config[k] = K.eval(v) if is_tensor_or_variable(v) else v
base_config = super(LossFunctionWrapper, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class MeanSquaredError(LossFunctionWrapper):
"""Computes the mean of squares of errors between labels and predictions.

For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is [1., 1., 1., 0.]
then the mean squared error value is 3/4 (0.75).

Standalone usage:

```python
mse = keras.losses.MeanSquaredError()
loss = mse([0., 0., 1., 1.], [1., 1., 1., 0.])
```

Usage with the `compile` API:

```python
model = keras.Model(inputs, outputs)
model.compile('sgd', loss=keras.losses.MeanSquaredError())
```

# Arguments
reduction: (Optional) Type of loss reduction to apply to loss.
Default value is `SUM_OVER_BATCH_SIZE`.
name: (Optional) name for the loss.
"""

def __init__(self,
reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
name='mean_squared_error'):
super(MeanSquaredError, self).__init__(
mean_squared_error, name=name, reduction=reduction)


def mean_squared_error(y_true, y_pred):
return K.mean(K.square(y_pred - y_true), axis=-1)
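The 3/4 figure quoted in the `MeanSquaredError` docstring can be checked directly against this function with plain NumPy:

```python
import numpy as np

y_true = np.array([0., 0., 1., 1.])
y_pred = np.array([1., 1., 1., 0.])
# Squared errors: [1., 1., 0., 1.]; their mean is 3/4.
mse = np.mean(np.square(y_pred - y_true))
print(mse)  # 0.75
```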

1 change: 1 addition & 0 deletions keras/utils/__init__.py
@@ -4,6 +4,7 @@
from . import data_utils
from . import io_utils
from . import conv_utils
from . import losses_utils

# Globally-importable utils.
from .io_utils import HDF5Matrix
136 changes: 136 additions & 0 deletions keras/utils/losses_utils.py
@@ -0,0 +1,136 @@
"""Utilities related to losses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from .. import backend as K


class Reduction(object):
"""Types of loss reduction.

Contains the following values:

* `NONE`: Un-reduced weighted losses with the same shape as input. When this
reduction type is used with built-in Keras training loops like
`fit`/`evaluate`, the unreduced vector loss is passed to the optimizer but
the reported loss will be a scalar value.
* `SUM`: Scalar sum of weighted losses.
* `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
"""

NONE = 'none'
SUM = 'sum'
SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'

@classmethod
def all(cls):
return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

@classmethod
def validate(cls, key):
if key not in cls.all():
raise ValueError('Invalid Reduction Key %s.' % key)


def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.

1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1.
2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
from the new rank of `y_pred`.
If `sample_weight` is scalar, it is kept scalar.

# Arguments
y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
y_true: Optional label `Tensor` whose dimensions match `y_pred`.
sample_weight: Optional weight scalar or `Tensor` whose dimensions match
`y_pred`.

# Returns
Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
the last dimension squeezed, `sample_weight` could be extended by one
dimension.
"""
if y_true is not None:
y_pred_rank = K.ndim(y_pred)
y_pred_shape = K.int_shape(y_pred)
y_true_rank = K.ndim(y_true)
y_true_shape = K.int_shape(y_true)

if (y_pred_rank - y_true_rank == 1) and (y_pred_shape[-1] == 1):
y_pred = K.squeeze(y_pred, -1)
elif (y_true_rank - y_pred_rank == 1) and (y_true_shape[-1] == 1):
y_true = K.squeeze(y_true, -1)

if sample_weight is None:
return y_pred, y_true, None

y_pred_rank = K.ndim(y_pred)
weights_rank = K.ndim(sample_weight)
if weights_rank != 0:
if weights_rank - y_pred_rank == 1:
sample_weight = K.squeeze(sample_weight, -1)
elif y_pred_rank - weights_rank == 1:
sample_weight = K.expand_dims(sample_weight, -1)
return y_pred, y_true, sample_weight
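The `sample_weight` branch of the rank-matching rule above can be sketched in NumPy (hypothetical helper mirroring only that branch; the `y_true`/`y_pred` squeeze is omitted):

```python
import numpy as np

def match_weight_rank(y_pred, sample_weight):
    # Scalars pass through unchanged; otherwise squeeze or expand the
    # trailing dim so the weight rank lines up with y_pred.
    w = np.asarray(sample_weight)
    if w.ndim == 0:
        return w
    if w.ndim - y_pred.ndim == 1:
        return np.squeeze(w, -1)
    if y_pred.ndim - w.ndim == 1:
        return np.expand_dims(w, -1)
    return w

y_pred = np.zeros((8, 1))
print(match_weight_rank(y_pred, np.ones(8)).shape)          # (8, 1)
print(match_weight_rank(y_pred, np.ones((8, 1, 1))).shape)  # (8, 1)
```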


def _num_elements(losses):
"""Computes the number of elements in `losses` tensor."""
with K.name_scope('num_elements') as scope:
return K.cast(K.size(losses, name=scope), losses.dtype)


def reduce_weighted_loss(weighted_losses, reduction=Reduction.SUM_OVER_BATCH_SIZE):
"""Reduces the individual weighted loss measurements."""
if reduction == Reduction.NONE:
loss = weighted_losses
else:
loss = K.sum(weighted_losses)
if reduction == Reduction.SUM_OVER_BATCH_SIZE:
loss = loss / _num_elements(weighted_losses)
return loss


def compute_weighted_loss(losses,
sample_weight=None,
reduction=Reduction.SUM_OVER_BATCH_SIZE,
name=None):
"""Computes the weighted loss.

# Arguments
losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
`losses`, or is broadcastable to `losses`.
reduction: (Optional) Type of Reduction to apply to loss.
Default value is `SUM_OVER_BATCH_SIZE`.
name: Optional name for the op.

# Raises
ValueError: If the shape of `sample_weight` is not compatible with `losses`.

# Returns
Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
`NONE`, this has the same shape as `losses`; otherwise, it is scalar.
"""
Reduction.validate(reduction)
if sample_weight is None:
sample_weight = 1.0
with K.name_scope(name or 'weighted_loss'):
input_dtype = K.dtype(losses)
losses = K.cast(losses, K.floatx())
sample_weight = K.cast(sample_weight, K.floatx())

# Update dimensions of `sample_weight` to match with `losses` if possible.
losses, _, sample_weight = squeeze_or_expand_dimensions(
losses, None, sample_weight)

weighted_losses = losses * sample_weight
# Apply reduction function to the individual weighted losses.
loss = reduce_weighted_loss(weighted_losses, reduction)
# Convert the result back to the input type.
loss = K.cast(loss, input_dtype)
return loss
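The three reductions applied to one set of weighted losses, as plain NumPy arithmetic (note that `SUM_OVER_BATCH_SIZE` divides by the element count, not by the sum of weights):

```python
import numpy as np

losses = np.array([1.0, 2.0, 3.0, 4.0])
sample_weight = np.array([1.0, 1.0, 0.0, 2.0])
weighted = losses * sample_weight              # [1., 2., 0., 8.]

reduced_none = weighted                        # NONE: unreduced vector
reduced_sum = weighted.sum()                   # SUM: 11.0
reduced_sobs = weighted.sum() / weighted.size  # SUM_OVER_BATCH_SIZE: 11/4
print(reduced_sum, reduced_sobs)  # 11.0 2.75
```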