@@ -1387,6 +1387,7 @@ def on_predict_batch_end(self, *args, **kwargs):
 
     @pytest.mark.requires_trainable_backend
     def test_metric_update_in_compute_loss(self):
+        test_self = self
 
         class MyModel(keras.Model):
             def __init__(self):
@@ -1398,9 +1399,17 @@ def call(self, x):
13981399                return  self .dense (x )
13991400
14001401            def  compute_loss (
1401-                 self , x = None , y = None , y_pred = None , sample_weight = None 
1402+                 self ,
1403+                 x = None ,
1404+                 y = None ,
1405+                 y_pred = None ,
1406+                 sample_weight = None ,
1407+                 training = True ,
14021408            ):
1403-                 loss  =  super ().compute_loss (x , y , y_pred , sample_weight )
1409+                 test_self .assertTrue (training )
1410+                 loss  =  super ().compute_loss (
1411+                     x , y , y_pred , sample_weight , training 
1412+                 )
14041413                self .custom_metric .update_state (loss  *  4 )
14051414                return  loss 
14061415
@@ -1415,6 +1424,7 @@ def compute_loss(
 
     @pytest.mark.requires_trainable_backend
     def test_fwd_pass_loss_presence_in_compute_loss(self):
+        test_self = self
 
         class MyModel(keras.Model):
             def __init__(self):
@@ -1426,9 +1436,17 @@ def call(self, x):
                 return self.dense(x)
 
             def compute_loss(
-                self, x=None, y=None, y_pred=None, sample_weight=None
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+                training=True,
             ):
-                loss = super().compute_loss(x, y, y_pred, sample_weight)
+                test_self.assertTrue(training)
+                loss = super().compute_loss(
+                    x, y, y_pred, sample_weight, training
+                )
                 self.custom_metric.update_state(sum(self.losses))
                 return loss
 
@@ -1439,6 +1457,75 @@ def compute_loss(
         history = model.fit(x, y)
         self.assertGreater(history.history["custom"][0], 0.0)
 
+    @pytest.mark.requires_trainable_backend
+    def test_evaluate_with_custom_compute_loss(self):
+        test_self = self
+
+        class MyModel(keras.Model):
+            def __init__(self):
+                super().__init__()
+                self.custom_metric = keras.metrics.Mean(name="custom")
+                self.dense = keras.layers.Dense(2, activity_regularizer="l2")
+
+            def call(self, x):
+                return self.dense(x)
+
+            def compute_loss(
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+                training=True,
+            ):
+                test_self.assertFalse(training)
+                loss = super().compute_loss(
+                    x, y, y_pred, sample_weight, training
+                )
+                self.custom_metric.update_state(loss * 4)
+                return loss
+
+        model = MyModel()
+        model.compile(optimizer="sgd", loss="mse")
+        x = np.ones((32, 4))
+        y = np.ones((32, 2)) * 2
+        logs = model.evaluate(x, y, return_dict=True)
+        self.assertAlmostEqual(logs["custom"], logs["loss"] * 4)
+
+    @pytest.mark.requires_trainable_backend
+    def test_compute_loss_no_training_backwards_compatibility(self):
+
+        class MyModel(keras.Model):
+            def __init__(self):
+                super().__init__()
+                self.custom_metric = keras.metrics.Mean(name="custom")
+                self.dense = keras.layers.Dense(2, activity_regularizer="l2")
+
+            def call(self, x):
+                return self.dense(x)
+
+            def compute_loss(
+                self,
+                x=None,
+                y=None,
+                y_pred=None,
+                sample_weight=None,
+            ):
+                loss = super().compute_loss(x, y, y_pred, sample_weight)
+                self.custom_metric.update_state(loss * 4)
+                return loss
+
+        model = MyModel()
+        model.compile(optimizer="sgd", loss="mse")
+        x = np.ones((32, 4))
+        y = np.ones((32, 2)) * 2
+        logs = model.evaluate(x, y, return_dict=True)
+        self.assertAlmostEqual(logs["custom"], logs["loss"] * 4)
+        history = model.fit(x, y)
+        self.assertAlmostEqual(
+            history.history["custom"][0], history.history["loss"][0] * 4
+        )
+
     @pytest.mark.requires_trainable_backend
     def test_loss_weights(self):
         epochs = 3
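
Taken together, these tests pin down the behavior of the new `training` argument on `Model.compute_loss`: it is `True` during `fit()`, `False` during `evaluate()`, and overrides written against the old four-argument signature keep working unchanged. Below is a minimal sketch of a user override against the new signature; the `ScaledLossModel` name and the synthetic data are illustrative, not part of this diff.

```python
import numpy as np
import keras


class ScaledLossModel(keras.Model):
    """Illustrative override of compute_loss using the new signature."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(2)

    def call(self, x):
        return self.dense(x)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        # Forward `training` to the base implementation; it defaults to
        # True, so overrides that omit the argument stay compatible.
        loss = super().compute_loss(x, y, y_pred, sample_weight, training)
        # `training` can also gate train-only loss terms here.
        return loss


model = ScaledLossModel()
model.compile(optimizer="sgd", loss="mse")
x = np.ones((8, 4))
y = np.ones((8, 2))
model.fit(x, y, verbose=0)       # compute_loss receives training=True
model.evaluate(x, y, verbose=0)  # compute_loss receives training=False
```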