241 changes: 169 additions & 72 deletions tests/tensor_parallel/test_tensor_parallel.py
@@ -15,16 +15,16 @@
# Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
# Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
# Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_forward

# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
import os
import tempfile
import warnings

from safetensors import safe_open

from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_available
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
from transformers.integrations.tensor_parallel import get_packed_weights, get_tensor_shard, repack_weights
from transformers.testing_utils import (
TestCasePlus,
backend_device_count,
@@ -37,6 +37,7 @@

if is_torch_available():
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


@@ -53,14 +54,14 @@ def setup_dist_env(rank, world_size, port):

if torch.cuda.is_available():
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
else:
torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size)
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

func(rank, *func_args, **func_kwargs)

torch.distributed.barrier()
torch.distributed.destroy_process_group()
dist.barrier()
dist.destroy_process_group()


def init_distributed(tp: int):
@@ -211,95 +212,169 @@ def test_tp_plan_none_handling(self):


# ====== TEST FUNCTIONS ======
def _test_model_forward_impl(rank):
"""Implementation of test_model_forward for distributed execution."""
def _test_model_dense_forward_impl(rank, mode):
"""Implementation for comparing TP and non-TP model outputs."""
model_id = "JackFram/llama-68m"

int(os.environ["RANK"])
int(os.environ["WORLD_SIZE"])
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
torch.distributed.barrier()

has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break

assert has_dtensor == 1, "TP model must have DTensor"
# Ensure same random seed for reproducibility
torch.manual_seed(0)

# Load tokenizer and prepare inputs - same for both models
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt")

# Load TP model first to determine device
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
dist.barrier()
if mode == "eval":
model_tp.eval()
else:
model_tp.train()

# Load non-TP model and move to same device as TP model
device = model_tp.device
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
model = model.to(device)
Comment on lines +235 to +238
Collaborator: ah, no device_map="auto" because you always run with torchrun?
Member Author (@3outeille, Nov 3, 2025): I run with pytest, and this calls torch.mp.spawn, which is pretty much like running with torchrun. I don't want to set device_map="auto" because I am already moving the model to a specific device manually right after.
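For readers skimming the thread, a minimal, self-contained sketch of the kind of mp.spawn launcher the reply refers to (illustrative only: the function names, the fixed port, and the way arguments are passed are assumptions, not the exact init_distributed/setup_dist_env implementation in this file):

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def _worker(rank, world_size, port, test_fn, test_args):
        # Each spawned process configures rendezvous, joins the process group,
        # runs the test body for its rank, then synchronizes and tears down.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)
            dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        else:
            dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
        test_fn(rank, *test_args)
        dist.barrier()
        dist.destroy_process_group()


    def run_distributed(world_size, test_fn, *test_args):
        port = 29500  # assumed free port for the sketch; a real helper might pick one dynamically
        mp.spawn(_worker, args=(world_size, port, test_fn, test_args), nprocs=world_size)

This mirrors what torchrun provides (per-rank environment variables plus an initialized process group), which is why the tests can be launched from plain pytest.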


if mode == "eval":
model.eval()
else:
model.train()

# Prepare inputs on the same device
input_ids = inputs.input_ids.to(device)

# Run forward pass on both models
with torch.no_grad():
# Non-TP model output
outputs = model(input_ids)
logits = outputs.logits

# TP model output
outputs_tp = model_tp(input_ids)
logits_tp = outputs_tp.logits

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model(inputs)
# Compare outputs - they should match
assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
)

next_token_logits = outputs[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
response = tokenizer.decode(next_token)
assert response == "with"
print("response:", response)
torch.distributed.barrier()
dist.barrier()


def _test_model_backward_pass_impl(rank):
"""Implementation of test_model_backward_pass for distributed execution."""
def _test_model_dense_backward_pass_impl(rank):
"""Implementation for comparing TP and non-TP model backward passes."""
model_id = "JackFram/llama-68m"

model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
torch.distributed.barrier()
torch.manual_seed(0)

# Dummy forward and backward pass
# Note that loss.backward() will fail if there is a bug in the TP implementation
inputs = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
labels = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
loss = model(inputs, labels=labels).loss
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
dist.barrier()
model_tp.train()

device = model_tp.device
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
model = model.to(device)
model.train()

batch_size, seq_length = 2, 10
torch.manual_seed(42) # Different seed for inputs to ensure they're deterministic
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)

outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()

torch.distributed.barrier()
outputs_tp = model_tp(input_ids, labels=labels)
loss_tp = outputs_tp.loss
loss_tp.backward()

assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
f"TP and non-TP model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
)

def _test_model_generate_impl(rank):
"""Implementation of test_model_generate for distributed execution."""
model_id = "JackFram/llama-68m"
# Compare gradients for matching parameters
# Note: TP model may have sharded parameters (DTensors), so we slice the reference gradient to match
for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
if param.grad is not None and param_tp.grad is not None:
grad = param.grad
grad_tp = param_tp.grad

int(os.environ["RANK"])
int(os.environ["WORLD_SIZE"])
if isinstance(param_tp.data, dist.tensor.DTensor):
placement = param_tp.data.placements[0]
if hasattr(placement, "dim") and placement.dim is not None:
grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
else:
grad_shard = grad
else:
grad_shard = grad

model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
torch.distributed.barrier()
grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp

model.forward = torch.compile(model.forward)
assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
)

has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
dist.barrier()

assert has_dtensor == 1, "TP model must have DTensor"

tokenizer = AutoTokenizer.from_pretrained(model_id)
def _test_model_dense_forward_compile_impl(rank, mode):
"""Implementation for comparing TP and non-TP model outputs with torch.compile."""
model_id = "JackFram/llama-68m"

torch.manual_seed(0)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt")

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
dist.barrier()
if mode == "eval":
model_tp.eval()
else:
model_tp.train()

output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
device = model_tp.device
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
model = model.to(device)

torch.distributed.barrier()
if mode == "eval":
model.eval()
else:
model.train()

# Compile both models
model.forward = torch.compile(model.forward)
model_tp.forward = torch.compile(model_tp.forward)

input_ids = inputs.input_ids.to(device)

with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits

outputs_tp = model_tp(input_ids)
logits_tp = outputs_tp.logits

assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
)

dist.barrier()

def _test_model_save_impl(rank, tmp_dir, is_torchrun):

def _test_model_dense_save_impl(rank, tmp_dir):
"""Implementation of test_model_save for distributed execution."""
model_id = "JackFram/llama-68m"
kwargs = {}

if os.environ.get("RANK", None) is not None:
kwargs["tp_plan"] = "auto"
if dist.is_initialized():
kwargs = {"tp_plan": "auto"}
result_dir = f"{tmp_dir}/tp"
else:
kwargs = {}
result_dir = f"{tmp_dir}/nontp"

model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
@@ -312,46 +387,68 @@ class TestTensorParallelBase(TestCasePlus):
nproc_per_node = None

@require_torch_multi_accelerator
def test_model_forward(self):
def test_model_dense_forward_eval(self):
"""Test that TP and non-TP models produce the same outputs in eval mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("eval")

@require_torch_multi_accelerator
def test_model_dense_forward_train(self):
"""Test that TP and non-TP models produce the same outputs in train mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("train")

@require_torch_multi_accelerator
def test_model_dense_backward_pass(self):
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_forward_impl)()
init_distributed(tp=self.nproc_per_node)(_test_model_dense_backward_pass_impl)()

@require_torch_multi_accelerator
def test_model_backward_pass(self):
def test_model_dense_forward_compile_eval(self):
"""Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_backward_pass_impl)()
init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("eval")

@require_torch_multi_accelerator
def test_model_generate(self):
def test_model_dense_forward_compile_train(self):
"""Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_generate_impl)()
init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("train")

@require_huggingface_hub_greater_or_equal("0.31.4")
@require_torch_multi_accelerator
def test_model_save(self):
def test_model_dense_save(self):
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

with tempfile.TemporaryDirectory() as tmp_dir:
# First run with TP (distributed)
init_distributed(tp=self.nproc_per_node)(_test_model_save_impl)(tmp_dir, True)
init_distributed(tp=self.nproc_per_node)(_test_model_dense_save_impl)(tmp_dir)

# Then run without TP (non-distributed)
_test_model_save_impl(0, tmp_dir, False)
_test_model_dense_save_impl(0, tmp_dir)

non_tp_model_path = os.path.join(tmp_dir, "nontp")
tp_model_path = os.path.join(tmp_dir, "tp")