@@ -319,35 +319,49 @@ def _test_model_dense_backward_pass_impl():
319319
320320 torch .distributed .barrier ()
321321
322- def _test_model_dense_generate_impl ( ):
323- """Implementation of test_model_generate for distributed execution ."""
322+ def _test_model_dense_forward_compile_impl ( mode ):
323+ """Implementation for comparing TP and non-TP model outputs with torch.compile ."""
324324 model_id = "JackFram/llama-68m"
325-
326- int (os .environ ["RANK" ])
327- int (os .environ ["WORLD_SIZE" ])
328-
329- model = AutoModelForCausalLM .from_pretrained (model_id , dtype = "auto" , tp_plan = "auto" )
330- torch .distributed .barrier ()
331-
332- model .forward = torch .compile (model .forward )
333-
334- has_dtensor = 0
335- for name , parameter in model .named_parameters ():
336- if isinstance (parameter .data , torch .distributed .tensor .DTensor ):
337- has_dtensor = 1
338- break
339-
340- assert has_dtensor == 1 , "TP model must has DTensor"
341-
325+
326+ torch .manual_seed (0 )
327+
342328 tokenizer = AutoTokenizer .from_pretrained (model_id , use_fast = False )
343329 prompt = "Can I help"
330+ inputs = tokenizer (prompt , return_tensors = "pt" )
331+
332+ model_tp = AutoModelForCausalLM .from_pretrained (model_id , dtype = "auto" , tp_plan = "auto" )
333+ torch .distributed .barrier ()
334+ if mode == "eval" :
335+ model_tp .eval ()
336+ else :
337+ model_tp .train ()
338+
339+ device = model_tp .device
340+ model = AutoModelForCausalLM .from_pretrained (model_id , dtype = "auto" )
341+ model = model .to (device )
344342
345- inputs = tokenizer (prompt , return_tensors = "pt" ).input_ids .to (model .device )
346- outputs = model .generate (inputs , max_new_tokens = 10 , cache_implementation = "static" )
347-
348- output_text = tokenizer .batch_decode (outputs , skip_special_tokens = True )
349- assert output_text [0 ].startswith (prompt ), f"Expected output to start with '{ prompt } ', got '{ output_text [0 ]} '"
343+ if mode == "eval" :
344+ model .eval ()
345+ else :
346+ model .train ()
347+
348+ # Compile both models
349+ model .forward = torch .compile (model .forward )
350+ model_tp .forward = torch .compile (model_tp .forward )
351+
352+ input_ids = inputs .input_ids .to (device )
353+
354+ with torch .no_grad ():
355+ outputs = model (input_ids )
356+ logits = outputs .logits
357+
358+ outputs_tp = model_tp (input_ids )
359+ logits_tp = outputs_tp .logits
350360
361+ assert torch .allclose (
362+ logits , logits_tp , atol = 1e-5 , rtol = 1e-5
363+ ), f"TP and non-TP model outputs differ. Max diff: { (logits - logits_tp ).abs ().max ().item ()} | Min diff: { (logits - logits_tp ).abs ().min ().item ()} "
364+
351365 torch .distributed .barrier ()
352366
353367
@@ -400,13 +414,24 @@ def test_model_dense_backward_pass(self):
400414 init_distributed (tp = self .nproc_per_node )(_test_model_dense_backward_pass_impl )()
401415
402416 @require_torch_multi_accelerator
403- def test_model_dense_generate (self ):
417+ def test_model_dense_forward_compile_eval (self ):
418+ """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
419+ if self .nproc_per_node is None :
420+ self .skipTest ("nproc_per_node not set" )
421+ if backend_device_count (torch_device ) < self .nproc_per_node :
422+ self .skipTest (f"Need at least { self .nproc_per_node } devices, have { backend_device_count (torch_device )} " )
423+
424+ init_distributed (tp = self .nproc_per_node )(_test_model_dense_forward_compile_impl )("eval" )
425+
426+ @require_torch_multi_accelerator
427+ def test_model_dense_forward_compile_train (self ):
428+ """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
404429 if self .nproc_per_node is None :
405430 self .skipTest ("nproc_per_node not set" )
406431 if backend_device_count (torch_device ) < self .nproc_per_node :
407432 self .skipTest (f"Need at least { self .nproc_per_node } devices, have { backend_device_count (torch_device )} " )
408433
409- init_distributed (tp = self .nproc_per_node )(_test_model_dense_generate_impl )( )
434+ init_distributed (tp = self .nproc_per_node )(_test_model_dense_forward_compile_impl )( "train" )
410435
411436 @require_huggingface_hub_greater_or_equal ("0.31.4" )
412437 @require_torch_multi_accelerator