@@ -55,6 +55,10 @@ class VHeadModelTester(TrlTestCase):
     trl_model_class = None
     transformers_model_class = None

+    def setUp(self):
+        super().setUp()
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
     def test_value_head(self):
         r"""
         Test if the v-head is added to the model successfully
@@ -207,8 +211,8 @@ def test_inference(self):
         EXPECTED_OUTPUT_SIZE = 3

         for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+            model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
             outputs = model(input_ids)

             # Check if the outputs are of the right size - here
@@ -250,8 +254,8 @@ def test_generate(self, model_name):
         Test if `generate` works for every model
         """
         generation_config = GenerationConfig(max_new_tokens=9)
-        model = self.trl_model_class.from_pretrained(model_name)
-        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)

         # Just check if the generation works
         _ = model.generate(input_ids, generation_config=generation_config)
@@ -263,7 +267,7 @@ def test_transformers_bf16_kwargs(self):
         run a dummy forward pass without any issue.
         """
         for model_name in self.all_model_names:
-            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(self.device)

             lm_head_namings = ["lm_head", "embed_out", "output_layer"]

@@ -276,7 +280,7 @@ def test_transformers_bf16_kwargs(self):
             if hasattr(trl_model.pretrained_model, lm_head_naming):
                 self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16)

-            dummy_input = torch.LongTensor([[0, 1, 0, 1]])
+            dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device)

             # check dummy forward pass works in half precision
             _ = trl_model(dummy_input)
@@ -323,9 +327,9 @@ def test_inference(self):
         EXPECTED_OUTPUT_SIZE = 3

         for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
-            decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+            model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
+            decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
             outputs = model(input_ids, decoder_input_ids=decoder_input_ids)

             # Check if the outputs are of the right size - here
@@ -367,9 +371,9 @@ def test_generate(self, model_name):
         Test if `generate` works for every model
         """
         generation_config = GenerationConfig(max_new_tokens=9)
-        model = self.trl_model_class.from_pretrained(model_name)
-        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
-        decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        model = self.trl_model_class.from_pretrained(model_name).to(self.device)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
+        decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)

         # Just check if the generation works
         _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config)
@@ -400,7 +404,7 @@ def test_transformers_bf16_kwargs(self):
         run a dummy forward pass without any issue.
         """
         for model_name in self.all_model_names:
-            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+            trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(self.device)

             lm_head_namings = self.trl_model_class.lm_head_namings

@@ -412,7 +416,7 @@ def test_transformers_bf16_kwargs(self):
             if hasattr(trl_model.pretrained_model, lm_head_naming):
                 self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)

-            dummy_input = torch.LongTensor([[0, 1, 0, 1]])
+            dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device)

             # check dummy forward pass works in half precision
             _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input)
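
The pattern this diff applies throughout is: pick the device once in `setUp`, then move both the model and every input tensor onto it, so forward passes never mix CPU and GPU tensors. A minimal sketch of that pattern, using a plain `unittest.TestCase` and a stand-in `torch.nn.Linear` in place of the TRL model classes (the test name and module below are illustrative, not from the PR):

```python
import unittest

import torch


class DeviceAwareTest(unittest.TestCase):
    def setUp(self):
        # Same fallback as the diff: use CUDA when available, else CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def test_forward_on_device(self):
        # Stand-in for `trl_model_class.from_pretrained(...)`; the key point
        # is that the model and its inputs land on the same device.
        model = torch.nn.Linear(4, 2).to(self.device)
        inputs = torch.randn(1, 4, device=self.device)
        outputs = model(inputs)
        self.assertEqual(outputs.device.type, torch.device(self.device).type)


if __name__ == "__main__":
    unittest.main()
```

Creating tensors with `device=self.device` (as the diff does for `input_ids`) avoids the extra host-to-device copy that `torch.tensor(...).to(self.device)` would incur, though both forms are correct.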