Skip to content

Commit 0f89165

Browse files
committed
Small fixes to the focal losses and categorical crossentropy
1 parent 0f8e81f commit 0f89165

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

keras/backend.py

Lines changed: 22 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5566,8 +5566,12 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
55665566
labels=target, logits=output, axis=axis
55675567
)
55685568

5569-
# scale preds so that the class probas of each sample sum to 1
5569+
# Rescale the predictions so that, for each sample,
5570+
# the class probabilities sum to 1.
5571+
# This is needed to ensure that the cross entropy is
5572+
# computed correctly.
55705573
output = output / tf.reduce_sum(output, axis, True)
5574+
55715575
# Compute cross entropy from probabilities.
55725576
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
55735577
output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
@@ -5647,7 +5651,7 @@ def categorical_focal_crossentropy(
56475651
)
56485652

56495653
if from_logits:
5650-
output = tf.nn.softmax(output, axis=axis)
5654+
output = softmax(output, axis=axis)
56515655

56525656
# Rescale the predictions so that, for each sample,
56535657
# the class probabilities sum to 1
@@ -5844,28 +5848,28 @@ def binary_focal_crossentropy(
58445848
where `alpha` is a float in the range of `[0, 1]`.
58455849
58465850
Args:
5847-
target: A tensor with the same shape as `output`.
5848-
output: A tensor.
5849-
apply_class_balancing: A bool, whether to apply weight balancing on the
5850-
binary classes 0 and 1.
5851-
alpha: A weight balancing factor for class 1, default is `0.25` as
5852-
mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5853-
gamma: A focusing parameter, default is `2.0` as mentioned in the
5854-
reference.
5855-
from_logits: Whether `output` is expected to be a logits tensor. By
5856-
default, we consider that `output` encodes a probability distribution.
5851+
target: A tensor with the same shape as `output`.
5852+
output: A tensor.
5853+
apply_class_balancing: A bool, whether to apply weight balancing on the
5854+
binary classes 0 and 1.
5855+
alpha: A weight balancing factor for class 1, default is `0.25` as
5856+
mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5857+
gamma: A focusing parameter, default is `2.0` as mentioned in the
5858+
reference.
5859+
from_logits: Whether `output` is expected to be a logits tensor. By
5860+
default, we consider that `output` encodes a probability distribution.
58575861
58585862
Returns:
5859-
A tensor.
5863+
A tensor.
58605864
"""
5861-
sigmoidal = tf.__internal__.smart_cond.smart_cond(
5862-
from_logits,
5863-
lambda: sigmoid(output),
5864-
lambda: output,
5865-
)
5865+
5866+
sigmoidal = sigmoid(output) if from_logits else output
5867+
58665868
p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal)
5869+
58675870
# Calculate focal factor
58685871
focal_factor = tf.pow(1.0 - p_t, gamma)
5872+
58695873
# Binary crossentropy
58705874
bce = binary_crossentropy(
58715875
target=target,

0 commit comments

Comments
 (0)