@@ -5567,8 +5567,12 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
5567
5567
labels = target , logits = output , axis = axis
5568
5568
)
5569
5569
5570
- # scale preds so that the class probas of each sample sum to 1
5570
+ # Adjust the predictions so that the probability of
5571
+ # each class for every sample adds up to 1
5572
+ # This is needed to ensure that the cross entropy is
5573
+ # computed correctly.
5571
5574
output = output / tf .reduce_sum (output , axis , True )
5575
+
5572
5576
# Compute cross entropy from probabilities.
5573
5577
epsilon_ = _constant_to_tensor (epsilon (), output .dtype .base_dtype )
5574
5578
output = tf .clip_by_value (output , epsilon_ , 1.0 - epsilon_ )
@@ -5845,28 +5849,29 @@ def binary_focal_crossentropy(
5845
5849
where `alpha` is a float in the range of `[0, 1]`.
5846
5850
5847
5851
Args:
5848
- target: A tensor with the same shape as `output`.
5849
- output: A tensor.
5850
- apply_class_balancing: A bool, whether to apply weight balancing on the
5851
- binary classes 0 and 1.
5852
- alpha: A weight balancing factor for class 1, default is `0.25` as
5853
- mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5854
- gamma: A focusing parameter, default is `2.0` as mentioned in the
5855
- reference.
5856
- from_logits: Whether `output` is expected to be a logits tensor. By
5857
- default, we consider that `output` encodes a probability distribution.
5852
+ target: A tensor with the same shape as `output`.
5853
+ output: A tensor.
5854
+ apply_class_balancing: A bool, whether to apply weight balancing on the
5855
+ binary classes 0 and 1.
5856
+ alpha: A weight balancing factor for class 1, default is `0.25` as
5857
+ mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5858
+ gamma: A focusing parameter, default is `2.0` as mentioned in the
5859
+ reference.
5860
+ from_logits: Whether `output` is expected to be a logits tensor. By
5861
+ default, we consider that `output` encodes a probability
5862
+ distribution.
5858
5863
5859
5864
Returns:
5860
- A tensor.
5865
+ A tensor.
5861
5866
"""
5862
- sigmoidal = tf .__internal__ .smart_cond .smart_cond (
5863
- from_logits ,
5864
- lambda : sigmoid (output ),
5865
- lambda : output ,
5866
- )
5867
+
5868
+ sigmoidal = sigmoid (output ) if from_logits else output
5869
+
5867
5870
p_t = target * sigmoidal + (1 - target ) * (1 - sigmoidal )
5871
+
5868
5872
# Calculate focal factor
5869
5873
focal_factor = tf .pow (1.0 - p_t , gamma )
5874
+
5870
5875
# Binary crossentropy
5871
5876
bce = binary_crossentropy (
5872
5877
target = target ,
@@ -5894,7 +5899,7 @@ def sigmoid(x):
5894
5899
Returns:
5895
5900
A tensor.
5896
5901
"""
5897
- return tf .sigmoid (x )
5902
+ return tf .math . sigmoid (x )
5898
5903
5899
5904
5900
5905
@keras_export ("keras.backend.hard_sigmoid" )
0 commit comments