🔓 Remove lm_head check in AutoModelForCausalLMWithValueHead (#2398)
* Remove lm_head check in `AutoModelForCausalLMWithValueHead`

* Style

* Remove test
qgallouedec authored Nov 29, 2024
1 parent ac26778 commit 94e4135
Showing 2 changed files with 3 additions and 19 deletions.
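
The net effect of this commit: `AutoModelForCausalLMWithValueHead.__init__` no longer verifies that the wrapped model exposes one of the hard-coded head attributes (`lm_head`, `embed_out`, `output_layer`) before attaching the value head, so models whose head goes by a different name can now be wrapped without hitting the old `ValueError`. A minimal sketch of the post-commit behavior, assuming `trl` and `transformers` are installed (the tiny model id below is the one used in the removed test):

    from transformers import AutoModelForCausalLM
    from trl import AutoModelForCausalLMWithValueHead

    # Wrapping a regular causal LM works as before: the value head is now attached
    # unconditionally, built from the wrapped model's config.
    model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel")

    # Before this commit, wrapping a module without any of the expected head
    # attributes (e.g. the bare transformer backbone, as in the removed test)
    # raised ValueError("The model does not have a language model head, ...").
    # After this commit, no such check is performed in __init__.
    backbone = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel").transformer
    wrapped = AutoModelForCausalLMWithValueHead.from_pretrained(backbone)  # the old lm_head ValueError is no longer raised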
tests/test_modeling_value_head.py (13 changes: 3 additions & 10 deletions)
@@ -265,14 +265,6 @@ def test_generate(self, model_name):
         # Just check if the generation works
         _ = model.generate(input_ids, generation_config=generation_config)
 
-    def test_raise_error_not_causallm(self):
-        # Test with a model without a LM head
-        model_id = "trl-internal-testing/tiny-GPT2LMHeadModel"
-        # This should raise a ValueError
-        with self.assertRaises(ValueError):
-            pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
-            _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
-
     def test_transformers_bf16_kwargs(self):
         r"""
         Test if the transformers kwargs are correctly passed
@@ -283,10 +275,11 @@ def test_transformers_bf16_kwargs(self):
         for model_name in self.all_model_names:
             trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
 
-            lm_head_namings = self.trl_model_class.lm_head_namings
+            lm_head_namings = ["lm_head", "embed_out", "output_layer"]
 
             self.assertTrue(
-                any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+                any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings),
+                "Can't test the model because it doesn't have any of the expected lm_head namings",
             )
 
             for lm_head_naming in lm_head_namings:
trl/models/modeling_value_head.py (9 changes: 0 additions & 9 deletions)
@@ -69,9 +69,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
     Class attributes:
         - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
             should be set to `transformers.AutoModelForCausalLM` for this class.
-        - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
-            wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed
-            for other models in the future
         - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
             by the `ValueHead` class. Currently, the supported args are:
             - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
@@ -86,7 +83,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
     """
 
     transformers_parent_class = AutoModelForCausalLM
-    lm_head_namings = ["lm_head", "embed_out", "output_layer"]
     supported_args = (
         "summary_dropout_prob",
         "v_head_initializer_range",
@@ -106,12 +102,7 @@ def __init__(self, pretrained_model, **kwargs):
         """
         super().__init__(pretrained_model, **kwargs)
         v_head_kwargs, _, _ = self._split_kwargs(kwargs)
-
-        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
-            raise ValueError("The model does not have a language model head, please use a model that has one.")
-
         self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
-
         self._init_weights(**v_head_kwargs)
 
     def _init_weights(self, **kwargs):
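
For callers who still want the removed guard, the same check can be reproduced on the caller side before wrapping. A minimal sketch using the exact attribute names the deleted code looked for; here `model` stands for any `transformers` model instance you intend to wrap (this snippet is illustrative and not part of trl after this commit):

    # Caller-side reproduction of the check that was removed from __init__.
    lm_head_namings = ("lm_head", "embed_out", "output_layer")
    if not any(hasattr(model, attribute) for attribute in lm_head_namings):
        raise ValueError("The model does not have a language model head, please use a model that has one.")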
