110 changes: 74 additions & 36 deletions tests/generation/test_utils.py
@@ -23,6 +23,7 @@
import unittest
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import pytest
@@ -927,32 +928,44 @@ def test_prompt_lookup_decoding_stops_at_eos(self):
self.assertTrue(output_prompt_lookup.shape[-1] == 10)

@pytest.mark.generate
def test_left_padding_compatibility(self):
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
def test_left_padding_compatibility(
@gante (Contributor, Author) commented on Sep 18, 2025:

TL;DR

  • better documented
  • support for custom input args (padded and unpadded), for much simpler overwrites
  • also prepare padded token_type_ids
  • forward other model kwargs

self, unpadded_custom_inputs: Optional[dict] = None, padded_custom_inputs: Optional[dict] = None
):
"""
Tests that adding left-padding yields the same logits as the original input. Exposes arguments for custom
inputs for overwrites, to prevent full rewrites of the test when all we need is model-specific input handling.

# First, filter out models that don't support left padding
# - The model must have generative capabilities
if len(self.all_generative_model_classes) == 0:
self.skipTest(reason="No generative architecture available for this model.")
! If you overwrite this test, make sure to document why you need to overwrite it !

# - The model must support padding
NOTE: left-padding results in small numerical differences. This is expected.
See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

Args:
unpadded_custom_inputs (`dict`, *optional*):
Used in test overwrites. Custom inputs to add/overwrite over the default test inputs.
padded_custom_inputs (`dict`, *optional*):
Used in test overwrites. Custom inputs to add/overwrite over the padded test input handcrafted in this
test. Commonly used e.g. with multimodal cross attention masks.
"""

# First, filter out models that don't support left padding
# 1. The model must support padding
if not self.has_attentions:
self.skipTest(reason="This model doesn't support padding.")

# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
# 2. [encoder-decoder] The model must be a decoder-only architecture. Encoder-based architectures can use
# right-padding in their (encoder) inputs. Encoder-decoder may use left-padding on their decoder inputs
# [TODO: lift this restriction? technically, we can test padding the decoder inputs.]
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _ = self.prepare_config_and_inputs_for_generate()
if config.get_text_config(decoder=True).is_encoder_decoder:
if config.is_encoder_decoder:
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")

# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
# 3. [old models] Decoder-only architectures derived from encoder-decoder models could support it in theory,
# but we haven't added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict()
Expand All @@ -963,48 +976,73 @@ def test_left_padding_compatibility(self):
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)

# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
# Now we can start testing
unpadded_custom_inputs = unpadded_custom_inputs or {}
padded_custom_inputs = padded_custom_inputs or {}

def _prepare_model_kwargs(model_inputs, signature):
model_kwargs = {"input_ids": model_inputs["input_ids"], "attention_mask": model_inputs["attention_mask"]}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = torch.cumsum(model_inputs["attention_mask"], dim=-1) - 1
position_ids.masked_fill_(model_inputs["attention_mask"] == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[1], device=torch_device)
cache_position = torch.arange(model_inputs["input_ids"].shape[1], device=torch_device)
model_kwargs["cache_position"] = cache_position
# forward all other inputs, if they are in the signature
model_kwargs.update({k: v for k, v in model_inputs.items() if k not in model_kwargs and k in signature})
return model_kwargs

for model_class in decoder_only_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()

# no cache as some models require special cache classes to be init outside forward
# No cache to simplify the test (some models need careful init)
model.generation_config.use_cache = False
inputs_dict.update(unpadded_custom_inputs)
# special case: an inexistent `attention_mask` is a full mask
inputs_dict["attention_mask"] = inputs_dict.get("attention_mask", None)
if inputs_dict["attention_mask"] is None:
inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["input_ids"])

# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
# Get output logits from inputs without padding
model_kwargs_wo_padding = _prepare_model_kwargs(inputs_dict, signature)
next_logits_wo_padding = model(**model_kwargs_wo_padding).logits[:, -1, :]

# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
# Prepare padding on common inputs (pad length 32)
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
token_type_ids = inputs_dict.get("token_type_ids", None)
pad_token_id = getattr(config.get_text_config(decoder=True), "pad_token_id", None) or 0
pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat(
(torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
if token_type_ids is not None:
padded_token_type_ids = torch.cat(
(
# Assumption: `0` is a good default value for padding token type ids
torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device),
token_type_ids,
),
dim=1,
)
else:
padded_token_type_ids = None

# Get output logits from inputs with left-padding (pad length 32)
padded_inputs_dict = copy.deepcopy(inputs_dict)
padded_inputs_dict["input_ids"] = padded_input_ids
padded_inputs_dict["attention_mask"] = padded_attention_mask
if padded_token_type_ids is not None:
padded_inputs_dict["token_type_ids"] = padded_token_type_ids
padded_inputs_dict.update(padded_custom_inputs)

model_kwargs_with_padding = _prepare_model_kwargs(padded_inputs_dict, signature)
next_logits_with_padding = model(**model_kwargs_with_padding).logits[:, -1, :]

# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
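
Side note on the `_prepare_model_kwargs` helper above: with left-padding, `cumsum(attention_mask) - 1` reproduces the positions the real tokens had before padding, while padded slots are masked to a dummy value. A minimal illustrative sketch (the tensor values are made up for the example):

```python
import torch

# Illustrative only: a batch of one sequence, left-padded by three positions.
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1]])

position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 1, 0, 1, 2, 3]])
# -> the four real tokens keep positions 0..3, exactly as in the unpadded input,
#    which is why the padded and unpadded logits are expected to match.
```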
83 changes: 3 additions & 80 deletions tests/models/bamba/test_modeling_bamba.py
@@ -438,88 +438,11 @@ def test_batching_equivalence(self):
super().test_batching_equivalence()
self.model_tester.use_input_mask = orig

# essentially the same test in test_utils, just adjustment for rtol for this model
@gante (Contributor, Author) commented on Sep 18, 2025:

this overwrite comment wasn't even true 😢

@pytest.mark.generate
def test_left_padding_compatibility(self):
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

# First, filter out models that don't support left padding
# - The model must have generative capabilities
if len(self.all_generative_model_classes) == 0:
self.skipTest(reason="No generative architecture available for this model.")

# - The model must support padding
if not self.has_attentions:
self.skipTest(reason="This model doesn't support padding.")

# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _ = self.prepare_config_and_inputs_for_generate()
if config.is_encoder_decoder:
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")

# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict()
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
)
if has_encoder_attributes:
self.skipTest(
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)

# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs

for model_class in decoder_only_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
input_ids = inputs_dict["input_ids"]

# - for left padding we absolutely need to use an all ones
# attention mask, so we do not use the one in inputs_dict
attention_mask = torch.ones_like(input_ids)

model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()

# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False

# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
# TODO: document why a random attention mask causes this test to fail, but a full mask doesn't
unpadded_custom_inputs = {"attention_mask": None}
super().test_left_padding_compatibility(unpadded_custom_inputs=unpadded_custom_inputs)

@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
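
The Bamba overwrite above only uses `unpadded_custom_inputs`. For the other hook, `padded_custom_inputs` (the docstring's multimodal cross-attention-mask case), an overwrite could look roughly like the sketch below; the `cross_attention_mask` key and padding shape are hypothetical and not taken from this PR:

```python
# Hypothetical sketch of a model-specific overwrite that also left-pads a
# model-specific input, then delegates everything else to the common test.
@pytest.mark.generate
def test_left_padding_compatibility(self):
    _, inputs_dict = self.prepare_config_and_inputs_for_generate()
    cross_attention_mask = inputs_dict["cross_attention_mask"]  # assumed model-specific input
    # Left-pad the extra mask by the same 32 positions the common test uses for `input_ids`
    pad = torch.zeros(
        (cross_attention_mask.shape[0], 32, *cross_attention_mask.shape[2:]),
        dtype=cross_attention_mask.dtype,
        device=torch_device,
    )
    padded_cross_attention_mask = torch.cat((pad, cross_attention_mask), dim=1)
    super().test_left_padding_compatibility(
        padded_custom_inputs={"cross_attention_mask": padded_cross_attention_mask}
    )
```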
84 changes: 0 additions & 84 deletions tests/models/blip_2/test_modeling_blip_2.py
@@ -19,7 +19,6 @@
import unittest

import numpy as np
import pytest
import requests
from parameterized import parameterized

@@ -597,89 +596,6 @@ def _check_generate_outputs(self, output, config, use_cache=False, num_return_se
output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
)

# overwrite because BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present
@pytest.mark.generate
def test_left_padding_compatibility(self):
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

# First, filter out models that don't support left padding
# - The model must have generative capabilities
if len(self.all_generative_model_classes) == 0:
self.skipTest(reason="No generative architecture available for this model.")

# - The model must support padding
if not self.has_attentions:
self.skipTest(reason="This model doesn't support padding.")

# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _ = self.prepare_config_and_inputs_for_generate()
if config.is_encoder_decoder:
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")

# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict()
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
)
if has_encoder_attributes:
self.skipTest(
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)

# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs

for model_class in decoder_only_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict.get("attention_mask")
pixel_values = inputs_dict["pixel_values"]
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()

# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False

# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs, pixel_values=pixel_values).logits[:, -1, :]

# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs, pixel_values=pixel_values).logits[:, -1, :]

# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)


# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
class Blip2TextModelTester:
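
The entire BLIP-2 overwrite above could be dropped because the rewritten common test forwards any test input whose name appears in the model's forward signature (the `model_kwargs.update(...)` line in `_prepare_model_kwargs`), so `pixel_values` now flows through automatically. A minimal sketch of that forwarding idea (the helper name below is ours, not the test's):

```python
import inspect


def collect_forward_kwargs(model, model_inputs: dict) -> dict:
    """Sketch of the signature-based forwarding used by the common test."""
    signature = inspect.signature(model.forward).parameters.keys()
    model_kwargs = {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
    }
    # Forward every other prepared input that the model's forward actually accepts,
    # e.g. `pixel_values` for BLIP-2-style multimodal models.
    model_kwargs.update({k: v for k, v in model_inputs.items() if k not in model_kwargs and k in signature})
    return model_kwargs
```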