4 changes: 3 additions & 1 deletion keras/src/backend/numpy/numpy.py
@@ -1056,7 +1056,9 @@ def divide_no_nan(x1, x2):
     )
     x1 = convert_to_tensor(x1, dtype)
     x2 = convert_to_tensor(x2, dtype)
-    return np.where(x2 == 0, 0, np.divide(x1, x2))
+    # No need for the double-where trick since we don't calculate gradients
+    # in the numpy backend.
+    return np.where(x2 == 0, np.array(0, dtype=dtype), np.divide(x1, x2))
 
 
 def true_divide(x1, x2):
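For context: in autodiff backends, `where(x2 == 0, 0, x1 / x2)` needs the "double-where" trick (masking the denominator before dividing) because gradients flow through both branches of `where`, and the discarded `x1 / 0` branch produces NaNs. The numpy backend computes no gradients, so a single `where` suffices. Below is a minimal standalone sketch of the patched behavior, with `np.result_type` and `np.asarray` as stand-ins for the Keras dtype resolution and `convert_to_tensor`:

```python
import numpy as np

def divide_no_nan(x1, x2):
    # Resolve a common dtype, then divide, returning a zero *of that dtype*
    # wherever the denominator is zero (a Python scalar 0 would not carry
    # the dtype).
    dtype = np.result_type(x1, x2)
    x1 = np.asarray(x1, dtype=dtype)
    x2 = np.asarray(x2, dtype=dtype)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(x2 == 0, np.array(0, dtype=dtype), np.divide(x1, x2))

print(divide_no_nan(np.float16([1.0, 2.0]), np.float16([0.0, 4.0])))
# -> [0.  0.5], dtype float16
```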
42 changes: 29 additions & 13 deletions keras/src/losses/loss.py
@@ -11,14 +11,17 @@
 class Loss(KerasSaveable):
     """Loss base class.
 
-    This is the class to subclass in order to create new
-    custom losses.
+    This is the class to subclass in order to create new custom losses.
 
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
-            this should be `"sum_over_batch_size"`.
-            Supported options are `"sum"`, `"sum_over_batch_size"`, `"mean"`
-            or `None`.
+            this should be `"sum_over_batch_size"`. Supported options are
+            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
+            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
+            sample size, and `"mean_with_sample_weight"` sums the loss and
+            divides by the sum of the sample weights. `"none"` and `None`
+            perform no aggregation. Defaults to `"sum_over_batch_size"`.
         name: Optional name for the loss instance.
         dtype: The dtype of the loss's computations. Defaults to `None`, which
             means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
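To make the new option concrete, here is a plain-numpy arithmetic sketch (not the Keras code path) of what each reduction yields for per-sample losses and sample weights:

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0])   # per-sample losses
weights = np.array([1.0, 0.5, 2.0, 1.5])  # sample weights
weighted = values * weights               # [1.0, 1.0, 6.0, 6.0]

print(weighted.sum())                  # "sum"                        -> 14.0
print(weighted.sum() / weighted.size)  # "sum_over_batch_size"/"mean" -> 3.5
print(weighted.sum() / weights.sum())  # "mean_with_sample_weight"    -> 2.8
```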
@@ -96,7 +99,14 @@ def _obj_type(self):
 
 
 def standardize_reduction(reduction):
-    allowed = {"sum_over_batch_size", "sum", None, "none", "mean"}
+    allowed = {
+        "sum_over_batch_size",
+        "sum",
+        None,
+        "none",
+        "mean",
+        "mean_with_sample_weight",
+    }
     if reduction not in allowed:
         raise ValueError(
             "Invalid value for argument `reduction`. "
@@ -127,7 +137,7 @@ def squeeze_or_expand_to_same_rank(x1, x2, expand_rank_1=True):
     return x1, x2
 
 
-def reduce_values(values, reduction="sum_over_batch_size"):
+def reduce_values(values, sample_weight=None, reduction="sum_over_batch_size"):
     if (
         reduction is None
         or reduction == "none"
@@ -136,11 +146,17 @@ def reduce_values(values, reduction="sum_over_batch_size"):
     ):
         return values
     loss = ops.sum(values)
-    if reduction in ("mean", "sum_over_batch_size"):
-        loss /= ops.cast(
-            ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")),
-            loss.dtype,
-        )
+    if reduction in ("sum_over_batch_size", "mean", "mean_with_sample_weight"):
+        if reduction == "mean_with_sample_weight" and sample_weight is not None:
+            divisor = ops.cast(ops.sum(sample_weight), loss.dtype)
+        else:
+            divisor = ops.cast(
+                ops.prod(
+                    ops.convert_to_tensor(ops.shape(values), dtype="int32")
+                ),
+                loss.dtype,
+            )
+        loss = ops.divide_no_nan(loss, divisor)
     return loss


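The control flow above reads as: sum first, then pick a divisor, then divide safely. A hedged numpy sketch that mirrors the diff (the `ops.*` calls are replaced with numpy equivalents; this is not the Keras implementation):

```python
import numpy as np

def reduce_values(values, sample_weight=None, reduction="sum_over_batch_size"):
    values = np.asarray(values, dtype="float32")
    if reduction in (None, "none") or values.ndim == 0:
        return values  # no aggregation requested
    loss = values.sum()
    if reduction in ("sum_over_batch_size", "mean", "mean_with_sample_weight"):
        if reduction == "mean_with_sample_weight" and sample_weight is not None:
            divisor = np.asarray(sample_weight, dtype="float32").sum()
        else:
            divisor = values.size  # element count, like ops.prod(ops.shape(...))
        loss = loss / divisor if divisor != 0 else 0.0  # divide_no_nan
    return loss

print(reduce_values([1.0, 1.0, 6.0, 6.0], [1.0, 0.5, 2.0, 1.5],
                    "mean_with_sample_weight"))  # 14.0 / 5.0 -> 2.8
```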
@@ -173,7 +189,7 @@ def reduce_weighted_values(
         values = values * sample_weight
 
     # Apply reduction function to the individual weighted losses.
-    loss = reduce_values(values, reduction)
+    loss = reduce_values(values, sample_weight, reduction)
     return loss


2 changes: 1 addition & 1 deletion keras/src/losses/loss_test.py
@@ -271,7 +271,7 @@ def test_dtype_arg(self):
 
         # `dtype` setter should raise AttributeError
         with self.assertRaises(AttributeError):
-            loss.dtype = "bfloat16"
+            loss_fn.dtype = "bfloat16"
 
     def test_default_dtype(self):
         y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="float32")
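The fix points the assertion at the right variable (`loss_fn`, the instance under test); assigning to `loss` exercised the wrong object, so the test could pass for the wrong reason. For background, a minimal sketch of why assigning to a read-only property raises `AttributeError` (`ExampleLoss` is hypothetical, not the Keras `Loss`):

```python
class ExampleLoss:
    @property
    def dtype(self):
        # Read-only: no @dtype.setter is defined.
        return "float32"

loss_fn = ExampleLoss()
print(loss_fn.dtype)        # float32
loss_fn.dtype = "bfloat16"  # AttributeError: property has no setter
```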