22 changes: 0 additions & 22 deletions src/transformers/generation/utils.py
@@ -1430,27 +1430,6 @@ def compute_transition_scores(

        return transition_scores

-    def _validate_model_class(self):
-        """
-        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
-        right class to use.
-        """
-        # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
-        # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
-        # safely call `GenerationMixin.generate`
-        if not self.can_generate():
-            terminations_with_generation_support = [
-                "ForCausalLM",
-                "ForConditionalGeneration",
-                "ForSpeechSeq2Seq",
-                "ForVision2Seq",
-            ]
-            raise TypeError(
-                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
-                "it doesn't have a language model head. Classes that support generation often end in one of these "
-                f"names: {terminations_with_generation_support}."
-            )
-
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
if assistant_model is None:
return
@@ -2213,7 +2192,6 @@ def generate(
"""

# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation

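Note on the deletion above: with `_validate_model_class` gone, `generate()` no longer rejects incompatible classes by inspecting class-name suffixes at call time; generation support is expressed through inheritance instead. A minimal sketch of the opt-in pattern (the model class here is hypothetical, and assumes `transformers` is installed):

```python
# Hypothetical sketch: generation support is opted into by inheriting
# `GenerationMixin` explicitly, rather than being validated inside `generate()`.
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin


class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    """Toy placeholder; a real model would implement a config and `forward`."""
```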
14 changes: 6 additions & 8 deletions src/transformers/modeling_utils.py
@@ -55,7 +55,7 @@
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
-from .generation import CompileConfig, GenerationConfig, GenerationMixin
+from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
@@ -1704,8 +1704,7 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


-# TODO (joao): remove `GenerationMixin` inheritance in v4.50
-class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.

@@ -2157,12 +2156,12 @@ def can_generate(cls) -> bool:
continue
if "PreTrainedModel" not in str(base) and base.can_generate():
return True
-        # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+        # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
         # was how we detected whether a model could generate.
-        if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
-            logger.warning_once(
+        if hasattr(cls, "prepare_inputs_for_generation"):  # implicit: doesn't inherit `GenerationMixin`
+            logger.warning(
                 f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
-                "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+                "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
                 "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
                 "to call `generate` and other related functions."
                 "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
@@ -2172,7 +2171,6 @@ def can_generate(cls) -> bool:
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
"to update it."
)
-            return True
# Otherwise, can't generate
return False

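The reworked `can_generate` above can be exercised directly. A small sanity check, assuming a recent `transformers` release where these two stock classes behave as in this PR (expected values hedged accordingly):

```python
# `can_generate` is a classmethod: causal-LM classes that inherit
# `GenerationMixin` should report True; masked-LM classes should report False.
from transformers import BertForMaskedLM, GPT2LMHeadModel

print(GPT2LMHeadModel.can_generate())  # expected: True
print(BertForMaskedLM.can_generate())  # expected: False (see the BERT hunk below)
```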
8 changes: 6 additions & 2 deletions src/transformers/models/auto/auto_factory.py
@@ -730,8 +730,12 @@ def add_generation_mixin_to_remote_model(model_class):

# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
# `prepare_inputs_for_generation` method.
-    has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
-    has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
+    has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
+        getattr(model_class, "generate")
+    )
+    has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
+        getattr(model_class, "prepare_inputs_for_generation")
+    )
if has_custom_generate or has_custom_prepare_inputs:
model_class_with_generation_mixin = type(
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
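The `type()` call above builds a patched class on the fly. A self-contained toy illustration of that pattern (stub classes only, not the transformers implementation):

```python
# Toy version of the dynamic-subclassing pattern used in `auto_factory.py`:
# create a class that keeps the legacy attributes and also gains `generate()`.
class RemoteModelBase:
    pass


class LegacyRemoteModel(RemoteModelBase):
    def prepare_inputs_for_generation(self, input_ids):
        return {"input_ids": input_ids}


class GenerationMixinStub:  # stand-in for `transformers.GenerationMixin`
    def generate(self):
        return "generated!"


Patched = type(
    LegacyRemoteModel.__name__, (LegacyRemoteModel, GenerationMixinStub), {**LegacyRemoteModel.__dict__}
)
print(Patched.__name__)      # LegacyRemoteModel
print(Patched().generate())  # generated!
```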
4 changes: 2 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -1512,8 +1512,8 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
@classmethod
def can_generate(cls) -> bool:
"""
-        Legacy correction: BertForMaskedLM can't call `generate()` from GenerationMixin.
-        Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
+        Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
+        `prepare_inputs_for_generation` method.
"""
return False

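The same override recurs for Ernie and RemBERT in the hunks below. A standalone sketch of the opt-out pattern (toy class, not the real model):

```python
# Toy sketch: keep the legacy `prepare_inputs_for_generation` for backward
# compatibility while explicitly opting out of `generate()` support.
class ToyMaskedLM:
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @classmethod
    def can_generate(cls) -> bool:
        return False  # a masked-LM head has no autoregressive decoding path


print(ToyMaskedLM.can_generate())  # False
```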
4 changes: 2 additions & 2 deletions src/transformers/models/ernie/modeling_ernie.py
@@ -1328,8 +1328,8 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
@classmethod
def can_generate(cls) -> bool:
"""
-        Legacy correction: ErnieForMaskedLM can't call `generate()` from GenerationMixin.
-        Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
+        Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
+        `prepare_inputs_for_generation` method.
"""
return False

4 changes: 2 additions & 2 deletions src/transformers/models/rag/modeling_rag.py
@@ -22,7 +22,7 @@
from torch import nn

from ...configuration_utils import PretrainedConfig
-from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -1122,7 +1122,7 @@ def _cat_and_pad(tensors, pad_token_id):
""",
RAG_START_DOCSTRING,
)
-class RagTokenForGeneration(RagPreTrainedModel):
+class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
8 changes: 8 additions & 0 deletions src/transformers/models/rembert/modeling_rembert.py
@@ -999,6 +999,14 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_

return {"input_ids": input_ids, "attention_mask": attention_mask}

+    @classmethod
+    def can_generate(cls) -> bool:
+        """
+        Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
+        `prepare_inputs_for_generation` method.
+        """
+        return False


@add_start_docstrings(
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
41 changes: 2 additions & 39 deletions src/transformers/models/speecht5/modeling_speecht5.py
@@ -24,6 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
@@ -2242,7 +2243,7 @@ def forward(
"""SpeechT5 Model with a speech encoder and a text decoder.""",
SPEECHT5_START_DOCSTRING,
)
-class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
+class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
[Review comment, contributor/author]: SpeechT5ForSpeechToText probably had some `generate` compatibility issues in recent versions, since its `prepare_inputs_for_generation` was not being updated (removing the global mixin triggered test failures).
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]

def __init__(self, config: SpeechT5Config):
@@ -2413,44 +2414,6 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

-    def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs,
-    ):
-        # Note that this model doesn't inherit from the generation mixin, has unique generate function
-
-        # cut decoder_input_ids if past is used
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if decoder_input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = decoder_input_ids.shape[1] - 1
-
-            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
-
-        return {
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
-            "decoder_input_ids": decoder_input_ids,
-            "attention_mask": attention_mask,
-            "head_mask": head_mask,
-            "decoder_head_mask": decoder_head_mask,
-            "cross_attn_head_mask": cross_attn_head_mask,
-            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
-        }
-
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
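The main job of the deleted `prepare_inputs_for_generation` was trimming `decoder_input_ids` against the cache; after this change that behavior presumably comes from `GenerationMixin`'s generic implementation. The trimming logic in isolation (requires `torch`):

```python
# Standalone sketch of the cache-aware trimming the removed method performed:
# drop decoder tokens that the cached past already covers.
import torch

decoder_input_ids = torch.tensor([[5, 6, 7, 8]])
past_length = 3  # in the removed code: past_key_values[0][0].shape[2]

if decoder_input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    remove_prefix_length = decoder_input_ids.shape[1] - 1  # keep only the final ID

print(decoder_input_ids[:, remove_prefix_length:])  # tensor([[8]])
```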
26 changes: 23 additions & 3 deletions tests/models/speecht5/test_modeling_speecht5.py
@@ -31,6 +31,7 @@
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property

+from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
@@ -314,6 +315,15 @@ def get_config(self):
vocab_size=self.vocab_size,
)

+    def get_subsampled_output_lengths(self, input_lengths):
+        """
+        Computes the output length of the convolutional layers
+        """
+        for stride in self.conv_stride:
+            input_lengths = (input_lengths // stride) - 1
+
+        return input_lengths
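
A quick worked example of the helper above, assuming hypothetical strides `(2, 2)` (the tester's actual `conv_stride` may differ):

```python
# Each convolution stage integer-divides the length by its stride and drops
# one frame, matching `get_subsampled_output_lengths`.
input_lengths = 16
for stride in (2, 2):
    input_lengths = (input_lengths // stride) - 1
print(input_lengths)  # (16 // 2) - 1 = 7, then (7 // 2) - 1 = 2
```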

def create_and_check_model_forward(self, config, inputs_dict):
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()

@@ -359,10 +369,8 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):


@require_torch
-class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
+class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
     all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
-    # Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
-    all_generative_model_classes = ()
is_encoder_decoder = True
test_pruning = False
test_headmasking = False
@@ -727,6 +735,18 @@ def _mock_init_weights(self, module):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)

+    @unittest.skip(reason="Temporarily broken")  # TODO (joao, eustache): have a look at this test
+    def test_generate_with_head_masking(self):
+        pass
+
+    @unittest.skip(reason="Temporarily broken")  # TODO (joao, eustache): have a look at this test
+    def test_generate_without_input_ids(self):
+        pass
+
+    @unittest.skip(reason="Very flaky")  # TODO (joao, eustache): have a look at this test
+    def test_generate_continue_from_past_key_values(self):
+        pass


@require_torch
@require_sentencepiece
6 changes: 3 additions & 3 deletions tests/utils/test_modeling_utils.py
@@ -1720,16 +1720,16 @@ class DummyBertWithParent(DummyBertWithMixin):
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

-        # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
-        # `GenerationMixin`)
+        # 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
+        # they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
        class DummyBertWithPrepareInputs(BertModel):
            def prepare_inputs_for_generation(self):
                pass

        with CaptureLogger(logger) as cl:
            can_generate = DummyBertWithPrepareInputs.can_generate()
        self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
-        self.assertTrue(can_generate)
+        self.assertFalse(can_generate)

def test_save_and_load_config_with_custom_generation(self):
"""