@@ -5566,8 +5566,12 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
5566
5566
labels = target , logits = output , axis = axis
5567
5567
)
5568
5568
5569
- # scale preds so that the class probas of each sample sum to 1
5569
+ # Adjust the predictions so that the probability of
5570
+ # each class for every sample adds up to 1
5571
+ # This is needed to ensure that the cross entropy is
5572
+ # computed correctly.
5570
5573
output = output / tf .reduce_sum (output , axis , True )
5574
+
5571
5575
# Compute cross entropy from probabilities.
5572
5576
epsilon_ = _constant_to_tensor (epsilon (), output .dtype .base_dtype )
5573
5577
output = tf .clip_by_value (output , epsilon_ , 1.0 - epsilon_ )
@@ -5647,7 +5651,7 @@ def categorical_focal_crossentropy(
5647
5651
)
5648
5652
5649
5653
if from_logits :
5650
- output = tf . nn . softmax (output , axis = axis )
5654
+ output = softmax (output , axis = axis )
5651
5655
5652
5656
# Adjust the predictions so that the probability of
5653
5657
# each class for every sample adds up to 1
@@ -5844,28 +5848,28 @@ def binary_focal_crossentropy(
5844
5848
where `alpha` is a float in the range of `[0, 1]`.
5845
5849
5846
5850
Args:
5847
- target: A tensor with the same shape as `output`.
5848
- output: A tensor.
5849
- apply_class_balancing: A bool, whether to apply weight balancing on the
5850
- binary classes 0 and 1.
5851
- alpha: A weight balancing factor for class 1, default is `0.25` as
5852
- mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5853
- gamma: A focusing parameter, default is `2.0` as mentioned in the
5854
- reference.
5855
- from_logits: Whether `output` is expected to be a logits tensor. By
5856
- default, we consider that `output` encodes a probability distribution.
5851
+ target: A tensor with the same shape as `output`.
5852
+ output: A tensor.
5853
+ apply_class_balancing: A bool, whether to apply weight balancing on the
5854
+ binary classes 0 and 1.
5855
+ alpha: A weight balancing factor for class 1, default is `0.25` as
5856
+ mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5857
+ gamma: A focusing parameter, default is `2.0` as mentioned in the
5858
+ reference.
5859
+ from_logits: Whether `output` is expected to be a logits tensor. By
5860
+ default, we consider that `output` encodes a probability distribution.
5857
5861
5858
5862
Returns:
5859
- A tensor.
5863
+ A tensor.
5860
5864
"""
5861
- sigmoidal = tf .__internal__ .smart_cond .smart_cond (
5862
- from_logits ,
5863
- lambda : sigmoid (output ),
5864
- lambda : output ,
5865
- )
5865
+
5866
+ sigmoidal = sigmoid (output ) if from_logits else output
5867
+
5866
5868
p_t = target * sigmoidal + (1 - target ) * (1 - sigmoidal )
5869
+
5867
5870
# Calculate focal factor
5868
5871
focal_factor = tf .pow (1.0 - p_t , gamma )
5872
+
5869
5873
# Binary crossentropy
5870
5874
bce = binary_crossentropy (
5871
5875
target = target ,
0 commit comments