File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -623,9 +623,7 @@ def __init__(
623623 else unwrapped_model .get_base_model ().forward
624624 )
625625 forward_params = inspect .signature (model_forward ).parameters
626- self .model_accepts_loss_kwargs = (
627- "loss_kwargs" in forward_params and forward_params ["loss_kwargs" ].kind == inspect .Parameter .VAR_KEYWORD
628- )
626+ self .model_accepts_loss_kwargs = any (k .kind == inspect .Parameter .VAR_KEYWORD for k in forward_params .values ())
629627
630628 self .neftune_noise_alpha = args .neftune_noise_alpha
631629
@@ -3651,7 +3649,10 @@ def training_step(
36513649 return loss_mb .reduce_mean ().detach ().to (self .args .device )
36523650
36533651 with self .compute_loss_context_manager ():
3654- loss = self .compute_loss (model , inputs , num_items_in_batch = num_items_in_batch )
3652+ if self .model_accepts_loss_kwargs :
3653+ loss = self .compute_loss (model , inputs )
3654+ else :
3655+ loss = self .compute_loss (model , inputs , num_items_in_batch = num_items_in_batch )
36553656
36563657 del inputs
36573658 if (
You can’t perform that action at this time.
0 commit comments