Skip to content

Commit

Permalink
[fix] Fix trainer.train(resume_from_checkpoint="...") with custom m…
Browse files Browse the repository at this point in the history
…odels (i.e. `trust_remote_code`) (#2918)

i.e. models that require trust_remote_code=True
  • Loading branch information
tomaarsen authored Sep 9, 2024
1 parent 0a32ec8 commit ac38c1c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
self.prompts = prompts or {}
self.default_prompt_name = default_prompt_name
self.similarity_fn_name = similarity_fn_name
self.trust_remote_code = trust_remote_code
self.truncate_dim = truncate_dim
self.model_card_data = model_card_data or SentenceTransformerModelCardData()
self._model_card_vars = {}
Expand Down
2 changes: 1 addition & 1 deletion sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def _save(self, output_dir: str | None = None, state_dict=None) -> None:
def _load_from_checkpoint(self, checkpoint_path: str) -> None:
from sentence_transformers import SentenceTransformer

loaded_model = SentenceTransformer(checkpoint_path)
loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
self.model.load_state_dict(loaded_model.state_dict())

def create_model_card(
Expand Down

0 comments on commit ac38c1c

Please sign in to comment.