Bring back loss information for multiple outputs #20023

Merged: 7 commits, Jul 24, 2024
Changes from 1 commit
Bring back loss info for multiple outputs
james77777778 committed Jul 21, 2024
commit 71a360b98e561d44c947e9143d61c660ec79f084
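In short, this commit restores per-output loss entries (e.g. `output_a_loss`, `output_b_loss`) in the `fit()` history of multi-output models, matching the behavior of `tf.keras`. A minimal sketch of the restored behavior, using an illustrative two-output model rather than the test helpers in this diff:

    import numpy as np
    from keras import layers, models

    inputs = layers.Input(shape=(3,))
    output_a = layers.Dense(1, name="output_a")(inputs)
    output_b = layers.Dense(1, activation="sigmoid", name="output_b")(inputs)
    model = models.Model(inputs, [output_a, output_b])
    model.compile(
        optimizer="sgd", loss=["mean_squared_error", "binary_crossentropy"]
    )

    x = np.random.rand(8, 3)
    y1 = np.random.rand(8, 1)
    y2 = np.random.randint(0, 2, size=(8, 1))
    hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
    # With this change the per-output losses are reported again:
    print(sorted(hist.history.keys()))
    # ["loss", "output_a_loss", "output_b_loss"]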
47 changes: 24 additions & 23 deletions keras/src/models/model_test.py
@@ -239,14 +239,13 @@ def test_functional_list_outputs_list_losses(self):
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_a_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_loss",
"output_b_mean_squared_error",
]
)
@@ -270,16 +269,15 @@ def test_functional_list_outputs_list_losses_abbr(self):
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_a_bce",
"output_a_mae",
"output_a_mse",
"output_b_acc",
# "output_b_loss",
"output_b_loss",
"output_b_mse",
]
)
@@ -303,14 +301,13 @@ def test_functional_list_outputs_nested_list_losses(self):
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_a_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_loss",
"output_b_mean_squared_error",
]
)
@@ -351,15 +348,14 @@ def test_functional_dict_outputs_dict_losses(self):
verbose=0,
)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_loss",
"output_b_mean_squared_error",
"output_b_weighted_accuracy",
"output_b_weighted_mean_squared_error",
@@ -396,15 +392,14 @@ def test_functional_list_outputs_dict_losses_metrics(self):
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_loss",
"output_b_mean_squared_error",
"output_b_weighted_accuracy",
"output_b_weighted_mean_squared_error",
@@ -436,18 +431,17 @@ def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
# `output_b_accuracy` doesn't have `weighted_` in its metric name.
# When a metric appears only in `weighted_metrics`, it skips the
# `weighted_` prefix. This behavior matches `tf.keras`.
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_loss",
"output_b_mean_squared_error",
]
)
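To make the comment above concrete: a metric listed only under `weighted_metrics` keeps its plain name. A hedged sketch, assuming the illustrative two-output model from the first sketch at the top of this diff:

    model.compile(
        optimizer="sgd",
        loss={
            "output_a": "mean_squared_error",
            "output_b": "binary_crossentropy",
        },
        metrics={"output_a": "mean_squared_error"},
        weighted_metrics={
            "output_a": "mean_squared_error",
            "output_b": "accuracy",  # appears only as a weighted metric
        },
    )
    # "accuracy" is only a weighted metric, so its history key is
    # "output_b_accuracy" rather than "output_b_weighted_accuracy";
    # "mean_squared_error" is in both lists, so output_a reports both
    # "output_a_mean_squared_error" and
    # "output_a_weighted_mean_squared_error".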
@@ -472,13 +466,12 @@ def test_functional_list_outputs_dict_losses_partial_metrics(self):
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_loss",
"output_b_accuracy",
# "output_b_loss",
"output_b_loss",
"output_b_mean_squared_error",
]
)
@@ -500,7 +493,10 @@ def test_functional_dict_outputs_with_single_tensor(self):
"output_b": "binary_crossentropy",
},
)
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"])
self.assertListEqual(hist_keys, ref_keys)

def test_functional_list_outputs_with_custom_compute_loss(self):
model = _get_model_with_custom_compute_loss()
@@ -514,7 +510,12 @@ def test_functional_list_outputs_with_custom_compute_loss(self):
model.compile(
optimizer="sgd", loss=["mean_squared_error", "binary_crossentropy"]
)
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
ref_keys = sorted(
["binary_crossentropy_loss", "loss", "mean_squared_error_loss"]
)
self.assertListEqual(hist_keys, ref_keys)

def test_functional_list_outputs_dict_losses_invalid_keys(self):
model = _get_model_multi_outputs_list()
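Note the naming fallback this test pins down: with a custom `compute_loss`, the history keys fall back to each loss's own name plus a `_loss` suffix (hence `mean_squared_error_loss` and `binary_crossentropy_loss` above), presumably because the inferred output names do not line up one-to-one with the flat losses. A hypothetical helper summarizing the rule added in `CompileLoss.build()` below:

    def loss_metric_name(i, flat_losses, output_names):
        # Illustration only; mirrors the naming rule in CompileLoss.build().
        if output_names is not None and len(output_names) == len(flat_losses):
            name = output_names[i]  # e.g. "output_a" -> "output_a_loss"
        else:
            name = flat_losses[i].name  # e.g. "mean_squared_error_loss"
        return name + "_loss"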
67 changes: 64 additions & 3 deletions keras/src/trainers/compile_utils.py
@@ -3,6 +3,7 @@
from keras.src import ops
from keras.src import tree
from keras.src.utils.naming import get_object_name
from keras.src.utils.tracking import Tracker


class MetricsList(metrics_module.Metric):
@@ -431,6 +432,39 @@ def __init__(
# Inferred by `y_pred` and `output_names`
self.inferred_output_names = None

# Use `Tracker` to track metrics for individual losses.
self._metrics = []
self._tracker = Tracker(
{
"metrics": (
lambda x: isinstance(x, metrics_module.Metric),
self._metrics,
)
}
)

@property
def metrics(self):
if not self.built:
return []
metrics = []
for m in self._metrics:
if m is not None:
metrics.append(m)
return metrics

@property
def variables(self):
# Avoid relying on implicit tracking, since `CompileLoss`
# may be instantiated or built in a no-tracking scope.
if not self.built:
return []
vars = []
for m in self.metrics:
if m is not None:
vars.extend(m.variables)
return vars

def build(self, y_true, y_pred):
loss = self._user_loss
loss_weights = self._user_loss_weights
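For context on the pattern above: `Tracker` pairs a predicate with a backing store, and `add_to_store()` (used in `build()` below) drops an object straight into that store. A rough standalone sketch, based only on the calls visible in this diff:

    from keras.src import metrics as metrics_module
    from keras.src.utils.tracking import Tracker

    store = []
    tracker = Tracker(
        {"metrics": (lambda x: isinstance(x, metrics_module.Metric), store)}
    )
    tracker.add_to_store("metrics", metrics_module.Mean(name="output_a_loss"))
    print([m.name for m in store])  # ["output_a_loss"]

The explicit `metrics` and `variables` properties then expose exactly this store, so the loss trackers remain discoverable even when `CompileLoss` is created inside a no-tracking scope.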
@@ -527,6 +561,21 @@ def build(self, y_true, y_pred):
for identifier, _y_true, _y_pred in zip(flat_losses, y_true, y_pred)
]

# Add `Mean` metric to the tracker for each loss.
if len(flat_losses) > 1:
for i, loss in enumerate(flat_losses):
if loss is not None:
if inferred_output_names is not None and len(
inferred_output_names
) == len(flat_losses):
name = inferred_output_names[i]
else:
name = loss.name
name += "_loss"
self._tracker.add_to_store(
"metrics", metrics_module.Mean(name=name)
)

self.flat_losses = flat_losses
self.flat_loss_weights = flat_loss_weights
self.filtered_y_true_keys = filtered_y_true_keys
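Note the `len(flat_losses) > 1` guard: a single-output model gets no extra tracker, so its history keeps a lone `loss` key (this is also why `call()` below pads `metrics` with a dummy `None`). An illustrative check:

    import numpy as np
    from keras import layers, models

    inputs = layers.Input(shape=(3,))
    single = models.Model(inputs, layers.Dense(1, name="output_a")(inputs))
    single.compile(optimizer="sgd", loss="mean_squared_error")
    hist = single.fit(np.random.rand(8, 3), np.random.rand(8, 1), verbose=0)
    print(sorted(hist.history.keys()))  # ["loss"] -- no "output_a_loss"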
@@ -595,23 +644,35 @@ def call(self, y_true, y_pred, sample_weight=None):
sample_weight = [sample_weight[0] for _ in range(len(y_true))]
else:
sample_weight = [None for _ in y_true]
if len(self.metrics) == 0:
# This means that the model has a single output. We need to add a
# dummy `None` for the following `zip` to function correctly.
metrics = [None]
else:
metrics = self.metrics

# Iterate all losses in flat form.
loss_values = []
for loss, y_t, y_p, loss_weight, sample_weight in zip(
for loss_fn, y_t, y_p, loss_weight, sample_weight, metric in zip(
self.flat_losses,
y_true,
y_pred,
self.flat_loss_weights,
sample_weight,
metrics,
):
if loss:
if loss_fn:
value = ops.cast(
loss(y_t, y_p, sample_weight), dtype=self.dtype
loss_fn(y_t, y_p, sample_weight), dtype=self.dtype
)
if loss_weight is not None:
value = ops.multiply(value, loss_weight)
loss_values.append(value)
# Record individual losses.
if metric:
metric.update_state(
value, sample_weight=tree.flatten(y_p)[0].shape[0]
)
if loss_values:
total_loss = sum(loss_values)
return total_loss
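Each per-output loss value is pushed into its `Mean` with the batch size as `sample_weight`, so the reported number is an average over samples across all batches rather than a plain average of per-batch means; the distinction matters when the final batch is smaller. A standalone sketch of that weighting:

    from keras import metrics

    m = metrics.Mean(name="output_a_loss")
    # Two batches: 4 samples with mean loss 1.0, then 2 samples with mean 4.0.
    m.update_state(1.0, sample_weight=4)
    m.update_state(4.0, sample_weight=2)
    print(float(m.result()))  # 2.0 = (4 * 1.0 + 2 * 4.0) / 6, not (1.0 + 4.0) / 2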
16 changes: 15 additions & 1 deletion keras/src/trainers/trainer.py
@@ -250,6 +250,8 @@ def metrics(self):
metrics.extend(super().metrics)
if self.compiled and self._compile_metrics is not None:
metrics += [self._compile_metrics]
if self.compiled and self._compile_loss is not None:
metrics.extend(self._compile_loss.metrics)
return metrics

@property
@@ -1004,10 +1006,13 @@ def _symbolic_build(self, iterator=None, data_batch=None):
self._compile_metrics is not None
and not self._compile_metrics.built
)
compile_loss_unbuilt = (
self._compile_loss is not None and not self._compile_loss.built
)
optimizer_unbuilt = (
self.optimizer is not None and not self.optimizer.built
)
if model_unbuilt or compile_metrics_unbuilt:
if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
# Create symbolic tensors matching an input batch.

def to_symbolic_input(v):
@@ -1052,6 +1057,15 @@ def to_symbolic_input(v):
y_pred,
sample_weight=sample_weight,
)
if compile_loss_unbuilt:
# Build `CompileLoss` state with `backend.compute_output_spec`.
backend.compute_output_spec(
self._compute_loss,
x,
y,
y_pred,
sample_weight=sample_weight,
)
if optimizer_unbuilt:
# Build optimizer
self.optimizer.build(self.trainable_variables)
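`backend.compute_output_spec` traces a function on symbolic `KerasTensor`s, so the branch above builds the `CompileLoss` state (the per-output `Mean` metrics and their variables) without executing any real computation, mirroring how `compile_metrics` is built just before it. A rough illustration of the mechanism, assuming a trivial function:

    from keras.src import backend

    x_spec = backend.KerasTensor(shape=(2, 3))
    # No real math runs here; only shapes/dtypes are traced, while any state
    # the traced function creates is built as a side effect.
    out_spec = backend.compute_output_spec(lambda t: t + 1.0, x_spec)
    print(out_spec.shape)  # (2, 3)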