
Commit 46df341

Add training argument to Model.compute_loss(). (#19840)

This allows models to perform different computations during training and evaluation. For instance, some expensive-to-compute metrics can be skipped during training and only computed during evaluation. Note that backwards compatibility with overrides that do not have the `training` argument is maintained.
1 parent 9ad3ca0 commit 46df341
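
Illustratively, the new argument lets a custom `compute_loss` override branch on `training`. The sketch below is not part of the commit; the model, the metric, and the choice of mean absolute error are made up for the example (assumes Keras >= 3.3):

```python
import keras
from keras import ops


class MyModel(keras.Model):
    """Illustrative model that skips an expensive metric during training."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(2)
        # Hypothetical expensive metric, only updated during evaluation.
        self.eval_metric = keras.metrics.Mean(name="expensive_eval_metric")

    def call(self, x):
        return self.dense(x)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        loss = super().compute_loss(x, y, y_pred, sample_weight, training)
        if not training:
            # Only pay for this computation in `evaluate()`, not in `fit()`.
            self.eval_metric.update_state(ops.mean(ops.abs(y - y_pred)))
        return loss
```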

File tree

6 files changed: +144, -21 lines

keras/src/backend/jax/trainer.py (1 addition, 0 deletions)

```diff
@@ -63,6 +63,7 @@ def compute_loss_and_updates(
             y=y,
             y_pred=y_pred,
             sample_weight=sample_weight,
+            training=training,
         )
         if losses:
             self._losses_override.clear()
```

keras/src/backend/numpy/trainer.py (2 additions, 2 deletions)

```diff
@@ -28,8 +28,8 @@ def test_step(self, data):
             y_pred = self(x, training=False)
         else:
             y_pred = self(x)
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        loss = self._compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
             loss, sample_weight=tree.flatten(x)[0].shape[0]
```

keras/src/backend/tensorflow/trainer.py (12 additions, 8 deletions)

```diff
@@ -51,8 +51,12 @@ def train_step(self, data):
             y_pred = self(x, training=True)
         else:
             y_pred = self(x)
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        loss = self._compute_loss(
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=True,
         )
         self._loss_tracker.update_state(
             loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
@@ -78,8 +82,8 @@ def test_step(self, data):
             y_pred = self(x, training=False)
         else:
             y_pred = self(x)
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        loss = self._compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
             loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
@@ -601,17 +605,17 @@ def compiled_loss(
         self, y, y_pred, sample_weight=None, regularization_losses=None
     ):
         warnings.warn(
-            "`model.compiled_loss()` is deprecated. "
-            "Instead, use `model.compute_loss(x, y, y_pred, sample_weight)`.",
+            "`model.compiled_loss()` is deprecated. Instead, use "
+            "`model.compute_loss(x, y, y_pred, sample_weight, training)`.",
         )
         return self.compute_loss(
             x=None, y=y, y_pred=y_pred, sample_weight=sample_weight
         )

     def loss(self, y, y_pred, sample_weight=None):
         warnings.warn(
-            "`model.loss` is deprecated. "
-            "Instead, use `model.compute_loss(x, y, y_pred, sample_weight)`.",
+            "`model.loss()` is deprecated. Instead, use "
+            "`model.compute_loss(x, y, y_pred, sample_weight, training)`.",
         )
         return self.compute_loss(
             x=None, y=y, y_pred=y_pred, sample_weight=sample_weight
```
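
For code migrating off the deprecated `compiled_loss()` and `loss()` methods, the updated warning text points at `compute_loss()`. A minimal sketch of that migration, assuming Keras 3.3+; the model and data below are placeholders:

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.ones((4, 3))
y = np.zeros((4, 1))
y_pred = model(x)

# Before (deprecated, emits the warning above):
#     loss = model.compiled_loss(y, y_pred)
# After, per the new warning text:
loss = model.compute_loss(x=None, y=y, y_pred=y_pred, training=False)
```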

keras/src/backend/torch/trainer.py (4 additions, 4 deletions)

```diff
@@ -49,8 +49,8 @@ def train_step(self, data):
         # for the weights from the previous train step.
         self.zero_grad()

-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        loss = self._compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
         )
         self._loss_tracker.update_state(
             loss, sample_weight=tree.flatten(x)[0].shape[0]
@@ -85,8 +85,8 @@ def test_step(self, data):
             y_pred = self(x, training=False)
         else:
             y_pred = self(x)
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        loss = self._compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
             loss, sample_weight=tree.flatten(x)[0].shape[0]
```

keras/src/trainers/trainer.py (34 additions, 3 deletions)

```diff
@@ -1,3 +1,4 @@
+import inspect
 import platform
 import warnings

@@ -25,6 +26,9 @@ def __init__(self):
         self.steps_per_execution = 1
         # Can be set by callbacks in on_train_begin
         self._initial_epoch = None
+        self._compute_loss_has_training_arg = (
+            "training" in inspect.signature(self.compute_loss).parameters
+        )

     @traceback_utils.filter_traceback
     @tracking.no_automatic_dependency_tracking
@@ -262,6 +266,7 @@ def compute_loss(
         y=None,
         y_pred=None,
         sample_weight=None,
+        training=True,
     ):
         """Compute the total loss, validate it, and return it.

@@ -276,7 +281,7 @@ def __init__(self, *args, **kwargs):
                 super().__init__(*args, **kwargs)
                 self.loss_tracker = metrics.Mean(name='loss')

-            def compute_loss(self, x, y, y_pred, sample_weight):
+            def compute_loss(self, x, y, y_pred, sample_weight, training=True):
                 loss = ops.means((y_pred - y) ** 2)
                 loss += ops.sum(self.losses)
                 self.loss_tracker.update_state(loss)
@@ -306,12 +311,15 @@ def metrics(self):
             y: Target data.
             y_pred: Predictions returned by the model (output of `model(x)`)
             sample_weight: Sample weights for weighting the loss function.
+            training: Whether we are training or evaluating the model.

         Returns:
             The total loss as a scalar tensor, or `None` if no loss results
             (which is the case when called by `Model.test_step`).
         """
-        del x  # The default implementation does not use `x`.
+        # The default implementation does not use `x` or `training`.
+        del x
+        del training
         losses = []
         if self._compile_loss is not None:
             loss = self._compile_loss(y, y_pred, sample_weight)
@@ -331,6 +339,27 @@ def metrics(self):
         total_loss = ops.sum(losses)
         return total_loss

+    def _compute_loss(
+        self,
+        x=None,
+        y=None,
+        y_pred=None,
+        sample_weight=None,
+        training=True,
+    ):
+        """Backwards compatibility wrapper for `compute_loss`.
+
+        This should be used instead of `compute_loss` within `train_step`
+        and `test_step` to support overrides of `compute_loss` that may
+        not have the `training` argument, as this argument was added in
+        Keras 3.3.
+        """
+        if self._compute_loss_has_training_arg:
+            return self.compute_loss(
+                x, y, y_pred, sample_weight, training=training
+            )
+        else:
+            return self.compute_loss(x, y, y_pred, sample_weight)
+
     def stateless_compute_loss(
         self,
         trainable_variables,
@@ -340,6 +369,7 @@ def stateless_compute_loss(
         y=None,
         y_pred=None,
         sample_weight=None,
+        training=True,
     ):
         var_mapping = list(zip(self.trainable_variables, trainable_variables))
         var_mapping.extend(
@@ -349,11 +379,12 @@
         with backend.StatelessScope(state_mapping=var_mapping) as scope:
             # Note that this is needed for the regularization loss, which needs
             # the latest value of train/non-trainable variables.
-            loss = self.compute_loss(
+            loss = self._compute_loss(
                 x,
                 y,
                 y_pred,
                 sample_weight=sample_weight,
+                training=training,
             )

         # Update non trainable vars (may have been updated in compute_loss)
```
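
The backwards-compatibility mechanism is the `inspect.signature` check registered in `__init__` and consulted by `_compute_loss`. Below is a standalone sketch of that dispatch pattern outside of Keras; the function names are invented for illustration:

```python
import inspect


def call_maybe_with_training(fn, *args, training=True):
    """Pass `training` to `fn` only if its signature accepts it.

    Sketch of the check behind `_compute_loss` above; `fn` stands in
    for a possibly overridden `compute_loss`.
    """
    if "training" in inspect.signature(fn).parameters:
        return fn(*args, training=training)
    return fn(*args)


def old_style(y, y_pred):  # no `training` argument (pre-3.3 override)
    return (y - y_pred) ** 2


def new_style(y, y_pred, training=True):  # accepts `training`
    return (y - y_pred) ** 2 if training else abs(y - y_pred)


assert call_maybe_with_training(old_style, 3.0, 1.0) == 4.0
assert call_maybe_with_training(new_style, 3.0, 1.0, training=False) == 2.0
```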

keras/src/trainers/trainer_test.py (91 additions, 4 deletions)

```diff
@@ -1387,6 +1387,7 @@ def on_predict_batch_end(self, *args, **kwargs):

     @pytest.mark.requires_trainable_backend
     def test_metric_update_in_compute_loss(self):
+        test_self = self

         class MyModel(keras.Model):
             def __init__(self):
@@ -1398,9 +1399,17 @@ def call(self, x):
                 return self.dense(x)

             def compute_loss(
-                self, x=None, y=None, y_pred=None, sample_weight=None
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+                training=True,
             ):
-                loss = super().compute_loss(x, y, y_pred, sample_weight)
+                test_self.assertTrue(training)
+                loss = super().compute_loss(
+                    x, y, y_pred, sample_weight, training
+                )
                 self.custom_metric.update_state(loss * 4)
                 return loss

@@ -1415,6 +1424,7 @@ def compute_loss(

     @pytest.mark.requires_trainable_backend
     def test_fwd_pass_loss_presence_in_compute_loss(self):
+        test_self = self

         class MyModel(keras.Model):
             def __init__(self):
@@ -1426,9 +1436,17 @@ def call(self, x):
                 return self.dense(x)

             def compute_loss(
-                self, x=None, y=None, y_pred=None, sample_weight=None
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+                training=True,
             ):
-                loss = super().compute_loss(x, y, y_pred, sample_weight)
+                test_self.assertTrue(training)
+                loss = super().compute_loss(
+                    x, y, y_pred, sample_weight, training
+                )
                 self.custom_metric.update_state(sum(self.losses))
                 return loss

@@ -1439,6 +1457,75 @@ def compute_loss(
         history = model.fit(x, y)
         self.assertGreater(history.history["custom"][0], 0.0)

+    @pytest.mark.requires_trainable_backend
+    def test_evaluate_with_custom_compute_loss(self):
+        test_self = self
+
+        class MyModel(keras.Model):
+            def __init__(self):
+                super().__init__()
+                self.custom_metric = keras.metrics.Mean(name="custom")
+                self.dense = keras.layers.Dense(2, activity_regularizer="l2")
+
+            def call(self, x):
+                return self.dense(x)
+
+            def compute_loss(
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+                training=True,
+            ):
+                test_self.assertFalse(training)
+                loss = super().compute_loss(
+                    x, y, y_pred, sample_weight, training
+                )
+                self.custom_metric.update_state(loss * 4)
+                return loss
+
+        model = MyModel()
+        model.compile(optimizer="sgd", loss="mse")
+        x = np.ones((32, 4))
+        y = np.ones((32, 2)) * 2
+        logs = model.evaluate(x, y, return_dict=True)
+        self.assertAlmostEqual(logs["custom"], logs["loss"] * 4)
+
+    @pytest.mark.requires_trainable_backend
+    def test_compute_loss_no_training_backwards_compatibility(self):
+
+        class MyModel(keras.Model):
+            def __init__(self):
+                super().__init__()
+                self.custom_metric = keras.metrics.Mean(name="custom")
+                self.dense = keras.layers.Dense(2, activity_regularizer="l2")
+
+            def call(self, x):
+                return self.dense(x)
+
+            def compute_loss(
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+            ):
+                loss = super().compute_loss(x, y, y_pred, sample_weight)
+                self.custom_metric.update_state(loss * 4)
+                return loss
+
+        model = MyModel()
+        model.compile(optimizer="sgd", loss="mse")
+        x = np.ones((32, 4))
+        y = np.ones((32, 2)) * 2
+        logs = model.evaluate(x, y, return_dict=True)
+        self.assertAlmostEqual(logs["custom"], logs["loss"] * 4)
+        history = model.fit(x, y)
+        self.assertAlmostEqual(
+            history.history["custom"][0], history.history["loss"][0] * 4
+        )
+
     @pytest.mark.requires_trainable_backend
     def test_loss_weights(self):
         epochs = 3
```
