diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index ddc4eb850c..be4932e62f 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -265,14 +265,6 @@ def test_generate(self, model_name): # Just check if the generation works _ = model.generate(input_ids, generation_config=generation_config) - def test_raise_error_not_causallm(self): - # Test with a model without a LM head - model_id = "trl-internal-testing/tiny-GPT2LMHeadModel" - # This should raise a ValueError - with self.assertRaises(ValueError): - pretrained_model = AutoModelForCausalLM.from_pretrained(model_id) - _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer) - def test_transformers_bf16_kwargs(self): r""" Test if the transformers kwargs are correctly passed @@ -283,10 +275,11 @@ def test_transformers_bf16_kwargs(self): for model_name in self.all_model_names: trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) - lm_head_namings = self.trl_model_class.lm_head_namings + lm_head_namings = ["lm_head", "embed_out", "output_layer"] self.assertTrue( - any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), + "Can't test the model because it doesn't have any of the expected lm_head namings", ) for lm_head_naming in lm_head_namings: diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 0797794013..592879ae3e 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -69,9 +69,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): Class attributes: - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This should be set to `transformers.AutoModelForCausalLM` for this class. - - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the - wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed - for other models in the future - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported by the `ValueHead` class. Currently, the supported args are: - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the @@ -86,7 +83,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): """ transformers_parent_class = AutoModelForCausalLM - lm_head_namings = ["lm_head", "embed_out", "output_layer"] supported_args = ( "summary_dropout_prob", "v_head_initializer_range", @@ -106,12 +102,7 @@ def __init__(self, pretrained_model, **kwargs): """ super().__init__(pretrained_model, **kwargs) v_head_kwargs, _, _ = self._split_kwargs(kwargs) - - if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings): - raise ValueError("The model does not have a language model head, please use a model that has one.") - self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) - self._init_weights(**v_head_kwargs) def _init_weights(self, **kwargs):