[fix] Fix torch_compile=True by always inserting a wrapped model into the loss (#2884)

This should make the torch_compile=True training argument start working.
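
For reference, a minimal sketch of the kind of training run this affects; the model name and toy data below are placeholders, not taken from the commit:

# Minimal sketch (assumed model name and toy data): with this fix, the
# torch.compile-wrapped model is the one handed to the loss during training.
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
train_dataset = Dataset.from_dict(
    {
        "anchor": ["What is the capital of France?", "Who wrote 1984?"],
        "positive": ["Paris is the capital of France.", "George Orwell wrote 1984."],
    }
)
loss = MultipleNegativesRankingLoss(model)  # this loss keeps a reference to the model

args = SentenceTransformerTrainingArguments(
    output_dir="output/mnrl-compile-test",
    torch_compile=True,  # previously this flag never reached the loss's copy of the model
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()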
tomaarsen authored Aug 30, 2024
1 parent 76d70f6 commit 52bf210
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sentence_transformers/trainer.py
@@ -14,7 +14,6 @@
 from transformers.integrations import WandbCallback
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.training_args import ParallelMode
 
 from sentence_transformers.data_collator import SentenceTransformerDataCollator
 from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
@@ -317,12 +316,13 @@ def compute_loss(
         if isinstance(loss_fn, dict) and dataset_name:
             loss_fn = loss_fn[dataset_name]
 
-        # Hackishly insert the distributed model into the loss function, if the loss stores the model
-        # Only called once per process
+        # Insert the wrapped (e.g. distributed or compiled) model into the loss function,
+        # if the loss stores the model. Only called once per process
         if (
-            self.args.parallel_mode != ParallelMode.NOT_PARALLEL
-            and hasattr(model, "module")
-            and hasattr(loss_fn, "model")
+            model == self.model_wrapped
+            and model != self.model  # Only if the model is wrapped
+            and hasattr(loss_fn, "model")  # Only if the loss stores the model
+            and loss_fn.model != model  # Only if the wrapped model is not already stored
         ):
             loss_fn = self.override_model_in_loss(loss_fn, model)
         loss = loss_fn(features, labels)
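
For context (not part of the commit itself): with torch_compile=True the Trainer wraps the model via torch.compile, so self.model_wrapped differs from self.model, but parallel_mode stays NOT_PARALLEL and the compiled wrapper exposes no .module attribute, so the old condition never fired and the loss kept calling the unwrapped model. The new condition only asks whether a wrapper exists and whether the loss still holds the unwrapped model. A small, hedged illustration of the wrapping behaviour in plain PyTorch:

# Hedged illustration of why the old check missed torch.compile (plain PyTorch, not from the commit)
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # returns an OptimizedModule wrapping `model`

print(compiled is model)            # False -> the model is wrapped
print(hasattr(compiled, "module"))  # False -> unlike DDP, no `.module` attribute is exposed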
