Skip to content

Commit d973e62

Browse files
Vaibhavs10 and ydshieh
authored
fix condition where torch_dtype auto collides with model_kwargs. (#39054)
* fix condition where torch_dtype auto collides with model_kwargs.
* update tests
* update comment
* fix
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 44b2316 commit d973e62

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

src/transformers/pipelines/__init__.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,13 +1005,21 @@ def pipeline(
10051005
model_kwargs["device_map"] = device_map
10061006
if torch_dtype is not None:
10071007
if "torch_dtype" in model_kwargs:
1008-
raise ValueError(
1009-
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
1010-
" arguments might conflict, use only one.)"
1011-
)
1012-
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
1013-
torch_dtype = getattr(torch, torch_dtype)
1014-
model_kwargs["torch_dtype"] = torch_dtype
1008+
# If the user did not explicitly provide `torch_dtype` (i.e. the function default "auto" is still
1009+
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
1010+
# raising. This prevents false positives like providing `torch_dtype` only via `model_kwargs` while the
1011+
# top-level argument keeps its default value "auto".
1012+
if torch_dtype == "auto":
1013+
torch_dtype = None
1014+
else:
1015+
raise ValueError(
1016+
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
1017+
" arguments might conflict, use only one.)"
1018+
)
1019+
if torch_dtype is not None:
1020+
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
1021+
torch_dtype = getattr(torch, torch_dtype)
1022+
model_kwargs["torch_dtype"] = torch_dtype
10151023

10161024
model_name = model if isinstance(model, str) else None
10171025

tests/pipelines/test_pipelines_image_text_to_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,11 @@ def test_small_model_pt_token(self):
161161
[
162162
{
163163
"input_text": "<image> What this is? Assistant: This is",
164-
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
164+
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
165165
},
166166
{
167167
"input_text": "<image> What this is? Assistant: This is",
168-
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
168+
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
169169
},
170170
],
171171
)

tests/pipelines/test_pipelines_text_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,11 +441,11 @@ def test_small_model_pt_bloom_accelerate(self):
441441
[{"generated_text": ("This is a test test test test test test")}],
442442
)
443443

444-
# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
444+
# torch_dtype will be automatically set to torch.bfloat16 if not provided - check: https://github.com/huggingface/transformers/pull/38882
445445
pipe = pipeline(
446446
model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False
447447
)
448-
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
448+
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
449449
out = pipe("This is a test")
450450
self.assertEqual(
451451
out,

0 commit comments

Comments (0)