Skip to content

Commit

Permalink
Convert torch_dtype as str to actual torch data type (i.e. "float…
Browse files Browse the repository at this point in the history
…16" …to `torch.float16`) (#28208)

* Convert torch_dtype as str to actual torch data type (i.e. "float16" to torch.float16)

* Check if passed torch_dtype is an attribute in torch

* Update src/transformers/pipelines/__init__.py

Check type via isinstance

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 102549f commit 51a06c1
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,8 @@ def pipeline(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype

model_name = model if isinstance(model, str) else None
Expand Down

0 comments on commit 51a06c1

Please sign in to comment.