Commit daf1f25

refactor compile part of test tensor parallel
1 parent ce2afaf

File tree

1 file changed (+51 -26 lines)

1 file changed

+51
-26
lines changed

tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 51 additions & 26 deletions
@@ -319,35 +319,49 @@ def _test_model_dense_backward_pass_impl():
 
     torch.distributed.barrier()
 
-def _test_model_dense_generate_impl():
-    """Implementation of test_model_generate for distributed execution."""
+def _test_model_dense_forward_compile_impl(mode):
+    """Implementation for comparing TP and non-TP model outputs with torch.compile."""
     model_id = "JackFram/llama-68m"
-
-    int(os.environ["RANK"])
-    int(os.environ["WORLD_SIZE"])
-
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
-    torch.distributed.barrier()
-
-    model.forward = torch.compile(model.forward)
-
-    has_dtensor = 0
-    for name, parameter in model.named_parameters():
-        if isinstance(parameter.data, torch.distributed.tensor.DTensor):
-            has_dtensor = 1
-            break
-
-    assert has_dtensor == 1, "TP model must has DTensor"
-
+
+    torch.manual_seed(0)
+
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    torch.distributed.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()
+
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
 
-    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
-
-    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()
+
+    # Compile both models
+    model.forward = torch.compile(model.forward)
+    model_tp.forward = torch.compile(model_tp.forward)
+
+    input_ids = inputs.input_ids.to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits
 
+    assert torch.allclose(
+        logits, logits_tp, atol=1e-5, rtol=1e-5
+    ), f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+
     torch.distributed.barrier()
 
 
@@ -400,13 +414,24 @@ def test_model_dense_backward_pass(self):
         init_distributed(tp=self.nproc_per_node)(_test_model_dense_backward_pass_impl)()
 
     @require_torch_multi_accelerator
-    def test_model_dense_generate(self):
+    def test_model_dense_forward_compile_eval(self):
+        """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("eval")
+
+    @require_torch_multi_accelerator
+    def test_model_dense_forward_compile_train(self):
+        """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
 
-        init_distributed(tp=self.nproc_per_node)(_test_model_dense_generate_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("train")
 
     @require_huggingface_hub_greater_or_equal("0.31.4")
     @require_torch_multi_accelerator

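For reference, the pattern the refactored test exercises can be reproduced outside the test harness. The sketch below is a minimal, assumption-laden standalone version: the torchrun launch command, script name, and world size are illustrative and not part of this commit, while the model ID, seeding, compile step, and 1e-5 tolerances mirror the diff above.

# compare_tp_compile.py -- hypothetical standalone version of the test pattern above.
# Assumed launch (not from this commit): torchrun --nproc_per_node=2 compare_tp_compile.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "JackFram/llama-68m"
torch.manual_seed(0)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
inputs = tokenizer("Can I help", return_tensors="pt")

# tp_plan="auto" is expected to initialize the process group and shard the
# weights as DTensors across the ranks torchrun started; the plain copy below
# serves as the single-device reference.
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
torch.distributed.barrier()
model_tp.eval()

model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto").to(model_tp.device)
model.eval()

# Compile both forwards so the comparison exercises the compiled graphs.
model.forward = torch.compile(model.forward)
model_tp.forward = torch.compile(model_tp.forward)

input_ids = inputs.input_ids.to(model_tp.device)
with torch.no_grad():
    logits = model(input_ids).logits
    logits_tp = model_tp(input_ids).logits

assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
    f"Max diff: {(logits - logits_tp).abs().max().item()}"
)
torch.distributed.barrier()

In eval mode the compiled TP and single-device logits are expected to match within the test's atol/rtol of 1e-5; the train variant in the diff differs only in the .train() mode flags, since both forward passes still run under torch.no_grad().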