Commit ef0a4cf

PR #17746: Minor improvements and code refactoring in backend.py
Imported from GitHub PR #17746

Small changes in backend.py, some of which were discussed in PR #17651.

Copybara import of the project:

--
0f89165 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Small fixes on focal losses and cat.crossentropy

--
3c193de by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Fix linting and sigmoid func

--
b87b656 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Revert the redirection of the internal function

Merging this change closes #17746

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17746 from Frightera:frightera_small_loss_fixes b87b656
PiperOrigin-RevId: 522179031
1 parent db138de commit ef0a4cf

1 file changed (+23, -18 lines)

keras/backend.py

Lines changed: 23 additions & 18 deletions
@@ -5567,8 +5567,12 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
             labels=target, logits=output, axis=axis
         )
 
-    # scale preds so that the class probas of each sample sum to 1
+    # Adjust the predictions so that the probability of
+    # each class for every sample adds up to 1
+    # This is needed to ensure that the cross entropy is
+    # computed correctly.
     output = output / tf.reduce_sum(output, axis, True)
+
     # Compute cross entropy from probabilities.
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
     output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
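
The normalization and clipping steps in this hunk can be illustrated with a minimal pure-Python sketch for a single sample. It is not the Keras implementation: the function name is reused only for readability, `EPSILON = 1e-7` is an assumed stand-in for `epsilon()`, and the final summation over classes is assumed from the usual cross-entropy definition rather than taken from the diff.

```python
import math

# Assumed fuzz factor for illustration; Keras reads it from epsilon().
EPSILON = 1e-7


def categorical_crossentropy(target, output, eps=EPSILON):
    """Cross entropy for one sample given non-negative class scores."""
    total = sum(output)
    # Rescale the scores so the class probabilities sum to 1.
    probs = [o / total for o in output]
    # Clip away exact 0 and 1 so that log() stays finite.
    probs = [min(max(p, eps), 1.0 - eps) for p in probs]
    # Standard cross entropy: -sum(target * log(prob)) over the classes.
    return -sum(t * math.log(p) for t, p in zip(target, probs))


# Scores that are proportional to each other give the same loss,
# because both are rescaled to the same probabilities first.
loss_scaled = categorical_crossentropy([0.0, 1.0, 0.0], [2.0, 6.0, 2.0])
loss_normed = categorical_crossentropy([0.0, 1.0, 0.0], [0.2, 0.6, 0.2])
```

Without the rescaling step, the two calls above would disagree, which is the point the new comment makes about computing the cross entropy correctly.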
@@ -5845,28 +5849,29 @@ def binary_focal_crossentropy(
     where `alpha` is a float in the range of `[0, 1]`.
 
     Args:
-      target: A tensor with the same shape as `output`.
-      output: A tensor.
-      apply_class_balancing: A bool, whether to apply weight balancing on the
-        binary classes 0 and 1.
-      alpha: A weight balancing factor for class 1, default is `0.25` as
-        mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
-      gamma: A focusing parameter, default is `2.0` as mentioned in the
-        reference.
-      from_logits: Whether `output` is expected to be a logits tensor. By
-        default, we consider that `output` encodes a probability distribution.
+        target: A tensor with the same shape as `output`.
+        output: A tensor.
+        apply_class_balancing: A bool, whether to apply weight balancing on the
+            binary classes 0 and 1.
+        alpha: A weight balancing factor for class 1, default is `0.25` as
+            mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
+        gamma: A focusing parameter, default is `2.0` as mentioned in the
+            reference.
+        from_logits: Whether `output` is expected to be a logits tensor. By
+            default, we consider that `output` encodes a probability
+            distribution.
 
     Returns:
-      A tensor.
+        A tensor.
     """
-    sigmoidal = tf.__internal__.smart_cond.smart_cond(
-        from_logits,
-        lambda: sigmoid(output),
-        lambda: output,
-    )
+
+    sigmoidal = sigmoid(output) if from_logits else output
+
     p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal)
+
     # Calculate focal factor
     focal_factor = tf.pow(1.0 - p_t, gamma)
+
     # Binary crossentropy
     bce = binary_crossentropy(
         target=target,
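
The focal-loss arithmetic in this hunk can be followed with a scalar, pure-Python sketch. It is an illustration under assumptions, not the Keras code: it works on single floats instead of tensors, computes the binary cross entropy directly from a clipped probability, and the class-balancing step is written from the docstring's description of `alpha`. The conditional expression mirrors the new line in the diff, which presumes `from_logits` is a plain Python bool.

```python
import math


def sigmoid(x):
    """Logistic function for a single float."""
    return 1.0 / (1.0 + math.exp(-x))


def binary_focal_crossentropy(
    target,
    output,
    apply_class_balancing=False,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    eps=1e-7,  # assumed clipping constant for this sketch
):
    # Plain conditional expression, as in the diff; needs a Python bool.
    p = sigmoid(output) if from_logits else output
    # Probability assigned to the true class.
    p_t = target * p + (1 - target) * (1 - p)
    # Focal factor: down-weights examples that are already well classified.
    focal_factor = (1.0 - p_t) ** gamma
    # Binary cross entropy on a clipped probability.
    p = min(max(p, eps), 1.0 - eps)
    bce = -(target * math.log(p) + (1 - target) * math.log(1.0 - p))
    loss = focal_factor * bce
    if apply_class_balancing:
        # alpha weights class 1, (1 - alpha) weights class 0.
        loss *= target * alpha + (1 - target) * (1 - alpha)
    return loss


from_probability = binary_focal_crossentropy(1.0, 0.5)
from_logit = binary_focal_crossentropy(1.0, 0.0, from_logits=True)
```

A logit of `0.0` maps to a probability of `0.5`, so both calls should agree; with `gamma=0.0` the focal factor is 1 and the result reduces to the ordinary binary cross entropy.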
@@ -5894,7 +5899,7 @@ def sigmoid(x):
     Returns:
         A tensor.
     """
-    return tf.sigmoid(x)
+    return tf.math.sigmoid(x)
 
 
 @keras_export("keras.backend.hard_sigmoid")
