Commit 515e9eb

[CI] Modify tests to handle device allocation for models (#3962)
1 parent 26442ab commit 515e9eb

File tree (6 files changed: 45 additions, 36 deletions)

  tests/test_modeling_geometric_mixture_wrapper.py
  tests/test_modeling_value_head.py
  tests/test_utils.py
  trl/trainer/dpo_trainer.py
  trl/trainer/gkd_trainer.py
  trl/trainer/kto_trainer.py

tests/test_modeling_geometric_mixture_wrapper.py (6 additions, 5 deletions)

@@ -25,16 +25,17 @@ class TestGeometricMixtureWrapper(TrlTestCase):
    def setUp(self):
        super().setUp()
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        self.model = AutoModelForCausalLM.from_pretrained(model_id)
-        self.ref_model = create_reference_model(self.model)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+        self.ref_model = create_reference_model(self.model).to(self.device)
        self.generation_config = GenerationConfig.from_pretrained(model_id)
        self.mixture_coef = 0.5
        self.wrapper = GeometricMixtureWrapper(
            self.model, self.ref_model, self.generation_config, mixture_coef=self.mixture_coef
        )

    def test_forward(self):
-        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device)
        attention_mask = torch.ones_like(input_ids)

        output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask)

@@ -44,7 +45,7 @@ def test_forward(self):
        self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size))

    def test_mixture_coefficient(self):
-        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device)
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():

@@ -59,7 +60,7 @@ def test_mixture_coefficient(self):
        self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5))

    def test_prepare_inputs_for_generation(self):
-        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device)
        attention_mask = torch.ones_like(input_ids)

        inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True)
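
As context for this change, here is a minimal, self-contained sketch of the device-allocation pattern the test now follows (the script is illustrative, not part of the commit): select CUDA when it is available, move the model there, and create the input tensors on the same device so the forward pass never mixes CPU and GPU tensors.

import torch
from transformers import AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# The tests use a tiny internal checkpoint; any small causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5").to(device)

# Inputs must live on the same device as the model's weights,
# otherwise PyTorch raises a device-mismatch RuntimeError.
input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=device)
attention_mask = torch.ones_like(input_ids)  # ones_like inherits the device of input_ids

with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
print(logits.shape)  # (1, 5, vocab_size)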

tests/test_modeling_value_head.py (18 additions, 14 deletions)

@@ -55,6 +55,10 @@ class VHeadModelTester(TrlTestCase):
    trl_model_class = None
    transformers_model_class = None

+    def setUp(self):
+        super().setUp()
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
    def test_value_head(self):
        r"""
        Test if the v-head is added to the model successfully

@@ -207,8 +211,8 @@ def test_inference(self):
        EXPECTED_OUTPUT_SIZE = 3

        for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+            model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
            outputs = model(input_ids)

            # Check if the outputs are of the right size - here

@@ -250,8 +254,8 @@ def test_generate(self, model_name):
        Test if `generate` works for every model
        """
        generation_config = GenerationConfig(max_new_tokens=9)
-        model = self.trl_model_class.from_pretrained(model_name)
-        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)

        # Just check if the generation works
        _ = model.generate(input_ids, generation_config=generation_config)

@@ -263,7 +267,7 @@ def test_transformers_bf16_kwargs(self):
        run a dummy forward pass without any issue.
        """
        for model_name in self.all_model_names:
-            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(self.device)

            lm_head_namings = ["lm_head", "embed_out", "output_layer"]

@@ -276,7 +280,7 @@ def test_transformers_bf16_kwargs(self):
                if hasattr(trl_model.pretrained_model, lm_head_naming):
                    self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16)

-            dummy_input = torch.LongTensor([[0, 1, 0, 1]])
+            dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device)

            # check dummy forward pass works in half precision
            _ = trl_model(dummy_input)

@@ -323,9 +327,9 @@ def test_inference(self):
        EXPECTED_OUTPUT_SIZE = 3

        for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
-            decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+            model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
+            decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
            outputs = model(input_ids, decoder_input_ids=decoder_input_ids)

            # Check if the outputs are of the right size - here

@@ -367,9 +371,9 @@ def test_generate(self, model_name):
        Test if `generate` works for every model
        """
        generation_config = GenerationConfig(max_new_tokens=9)
-        model = self.trl_model_class.from_pretrained(model_name)
-        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
-        decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
+        decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)

        # Just check if the generation works
        _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config)

@@ -400,7 +404,7 @@ def test_transformers_bf16_kwargs(self):
        run a dummy forward pass without any issue.
        """
        for model_name in self.all_model_names:
-            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(self.device)

            lm_head_namings = self.trl_model_class.lm_head_namings

@@ -412,7 +416,7 @@ def test_transformers_bf16_kwargs(self):
                if hasattr(trl_model.pretrained_model, lm_head_naming):
                    self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)

-            dummy_input = torch.LongTensor([[0, 1, 0, 1]])
+            dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device)

            # check dummy forward pass works in half precision
            _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input)
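
The bf16 tests combine the same device move with a dtype constraint. A minimal sketch, assuming a plain causal LM rather than the TRL value-head wrapper used in the test: weights are loaded in bfloat16, moved to the selected device (which preserves dtype), and the integer input is moved to the same device before the dummy forward pass.

import torch
from transformers import AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load directly in bfloat16; .to(device) changes placement but not dtype.
model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", torch_dtype=torch.bfloat16
).to(device)
assert model.lm_head.weight.dtype == torch.bfloat16

# Token ids stay int64; only their device changes.
dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(device)
_ = model(dummy_input)  # dummy forward pass in half precision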

tests/test_utils.py (4 additions, 3 deletions)

@@ -329,7 +329,8 @@ def setUp(self):
        super().setUp()
        # Initialize the tokenizer
        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_id).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        self.generation_config = GenerationConfig(

@@ -350,7 +351,7 @@ def test_mini_batch_generation(self):
            self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False)
            for example in self.examples
        ]
-        queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"]
+        queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"].to(self.device)
        bs, context_length = queries.shape

        query_responses, logits = batch_generation(

@@ -369,7 +370,7 @@ def test_single_batch_generation(self):
            self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False)
            for example in self.examples
        ]
-        queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"]
+        queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"].to(self.device)
        bs, context_length = queries.shape

        query_responses, logits = batch_generation(
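
Tokenizers always return CPU tensors, so the queries have to be moved to the model's device before generation. A minimal sketch, assuming the tiny Qwen2 checkpoint used in setUp and some illustrative prompts; calling .to(device) on the whole BatchEncoding is an equivalent alternative that moves every tensor field at once.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

batch = ["Hello, how are you?", "What is the capital of France?"]  # illustrative prompts

# Move only the input_ids tensor, as the test does ...
queries = tokenizer(batch, padding=True, return_tensors="pt")["input_ids"].to(device)

# ... or move the whole BatchEncoding so input_ids and attention_mask both follow the model.
encoded = tokenizer(batch, padding=True, return_tensors="pt").to(device)
_ = model.generate(encoded["input_ids"], attention_mask=encoded["attention_mask"], max_new_tokens=4)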

trl/trainer/dpo_trainer.py (9 additions, 8 deletions)

@@ -1333,10 +1333,11 @@ def _compute_loss_liger(
        model_kwargs["attention_mask"] = attention_mask

        # Get the base model outputs (before LM head)
-        if hasattr(unwrapped_model, "get_decoder"):
+        if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None:
            base_model = unwrapped_model.get_decoder()
        else:
-            base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model)
+            base_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name)
+            base_model = getattr(unwrapped_model, base_attr, unwrapped_model)

        outputs = base_model(
            input_ids,

@@ -1349,12 +1350,11 @@ def _compute_loss_liger(
        ref_hidden_states = None
        if not self.reference_free and self.ref_model is not None:
            unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
-            if hasattr(unwrapped_ref_model, "get_decoder"):
+            if hasattr(unwrapped_ref_model, "get_decoder") and unwrapped_ref_model.get_decoder() is not None:
                ref_base_model = unwrapped_ref_model.get_decoder()
            else:
-                ref_base_model = getattr(
-                    unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model
-                )
+                ref_attr = getattr(unwrapped_ref_model, "base_model_prefix", self.args.base_model_attribute_name)
+                ref_base_model = getattr(unwrapped_ref_model, ref_attr, unwrapped_ref_model)

            ref_outputs = ref_base_model(
                input_ids,

@@ -1363,10 +1363,11 @@ def _compute_loss_liger(
            )
            ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
        elif not self.reference_free:
-            if hasattr(unwrapped_model, "get_decoder"):
+            if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None:
                ref_base_model = unwrapped_model.get_decoder()
            else:
-                ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model)
+                ref_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name)
+                ref_base_model = getattr(unwrapped_model, ref_attr, unwrapped_model)
            with self.null_ref_context():
                ref_outputs = ref_base_model(
                    input_ids,
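
The trainer-side changes all implement the same lookup order for the pre-LM-head backbone: use get_decoder() only when it returns an actual module, otherwise read the attribute the model advertises via base_model_prefix, then the configured base_model_attribute_name, and finally fall back to the model itself so the getattr never raises. A standalone sketch of that order (the helper name is illustrative, not part of the commit):

import torch
from transformers import AutoModelForCausalLM

def resolve_base_model(model, fallback_attr="model"):
    # 1) Prefer get_decoder(), but only if it actually returns a module;
    #    some wrappers expose the method yet return None.
    if hasattr(model, "get_decoder") and model.get_decoder() is not None:
        return model.get_decoder()
    # 2) Fall back to the attribute name the model itself declares (e.g. "model", "transformer"),
    #    then to a configured attribute name, and finally to the model itself.
    attr = getattr(model, "base_model_prefix", fallback_attr)
    return getattr(model, attr, model)

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
base = resolve_base_model(model)
hidden = base(torch.tensor([[1, 2, 3, 4, 5]])).last_hidden_state  # hidden states before the LM head
print(hidden.shape)

Passing a default to the final getattr also matters for the kto_trainer path below, which previously called getattr without a default and would raise AttributeError when the attribute was missing.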

trl/trainer/gkd_trainer.py (2 additions, 2 deletions)

@@ -243,7 +243,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
        if self.use_liger_gkd_loss:
            # Forward only through the base models (avoid lm_head to save memory)
            unwrapped_student = self.accelerator.unwrap_model(model)
-            if hasattr(unwrapped_student, "get_decoder"):
+            if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
                base_student = unwrapped_student.get_decoder()
            else:
                base_student = getattr(

@@ -259,7 +259,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

            self.teacher_model.eval()
            unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
-            if hasattr(unwrapped_teacher, "get_decoder"):
+            if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None:
                base_teacher = unwrapped_teacher.get_decoder()
            else:
                base_teacher = getattr(

trl/trainer/kto_trainer.py (6 additions, 4 deletions)

@@ -1271,10 +1271,11 @@ def _compute_loss_liger(self, model, batch):
            )
        else:
            # skip the lm head and get the last hidden state
-            if hasattr(model, "get_decoder"):
+            if hasattr(model, "get_decoder") and model.get_decoder() is not None:
                base_model = model.get_decoder()
            else:
-                base_model = getattr(model, self.args.base_model_attribute_name)
+                base_attr = getattr(model, "base_model_prefix", self.args.base_model_attribute_name)
+                base_model = getattr(model, base_attr, model)
            outputs = base_model(
                batch["completion_input_ids"],
                attention_mask=batch["completion_attention_mask"],

@@ -1283,10 +1284,11 @@ def _compute_loss_liger(self, model, batch):
            )

            # reference model
-            if hasattr(self.ref_model, "get_decoder"):
+            if hasattr(self.ref_model, "get_decoder") and self.ref_model.get_decoder() is not None:
                ref_base_model = self.ref_model.get_decoder()
            else:
-                ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name)
+                ref_attr = getattr(self.ref_model, "base_model_prefix", self.args.base_model_attribute_name)
+                ref_base_model = getattr(self.ref_model, ref_attr, self.ref_model)
            ref_outputs = ref_base_model(
                batch["completion_input_ids"],
                attention_mask=batch["completion_attention_mask"],
