Skip to content

Commit

Permalink
[from_pretrained] imporve the error message when _no_split_modules
Browse files Browse the repository at this point in the history
…is not defined (huggingface#23861)

* Better warning

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* format line

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
2 people authored and gojiteji committed Jun 5, 2023
1 parent 7e0aeec commit 576755e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2734,7 +2734,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
target_dtype = torch.int8

if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
raise ValueError(
f"{model.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model"
"class needs to implement the `_no_split_modules` attribute."
)
no_split_modules = model._no_split_modules
if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
Expand Down

0 comments on commit 576755e

Please sign in to comment.