Refactor model_parallel tests to allow different (device, backend) combination (pytorch#1667)

Summary:

Refactor the model_parallel tests so they can run more combinations of (device, backend).

1. Do not force the device to CUDA when the backend is NCCL, or to CPU when the backend is Gloo; i.e. allow the (NCCL, CPU) and (Gloo, GPU) combinations (see the sketch after this list).
2. Refactor test_parameter_init to exercise parameter initialization on both the (nccl, gpu) and (gloo, cpu) combinations.
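
A minimal sketch of the resulting pattern, condensed from the diffs below (the class names here are illustrative, not the actual torchrec test classes): the test bases now take the backend as a setUp argument and derive the device purely from hardware availability, and backend-specific test files only override the default.

import torch


class ExampleParallelTestBase:
    # Illustrative stand-in for the refactored test base classes.

    def setUp(self, backend: str = "nccl") -> None:
        # The backend is taken as passed in (subclasses supply "nccl" or "gloo")...
        self.backend = backend
        # ...while the device depends only on what hardware is available,
        # so (gloo, cuda) and (nccl, cpu) runs are both possible.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")


class ExampleParallelTestGloo(ExampleParallelTestBase):
    def setUp(self, backend: str = "gloo") -> None:
        super().setUp(backend=backend)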

Reviewed By: sarckk

Differential Revision: D53149924
henrylhtsang authored and facebook-github-bot committed Jan 30, 2024
1 parent 2005d7f commit ea9b20d
Showing 7 changed files with 64 additions and 53 deletions.
13 changes: 6 additions & 7 deletions torchrec/distributed/test_utils/multi_process.py
@@ -36,16 +36,15 @@ def __init__(
self.backend = backend
self.local_size = local_size

if backend == "nccl":
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
self.device: torch.device = device
torch.use_deterministic_algorithms(True)
if torch.cuda.is_available():
self.device: torch.device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(self.device)

torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
self.device: torch.device = torch.device("cpu")
torch.use_deterministic_algorithms(True)
self.pg: Optional[dist.ProcessGroup] = None

def __enter__(self) -> "MultiProcessContext":
19 changes: 14 additions & 5 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -29,7 +29,7 @@

class ModelParallelTestShared(MultiProcessTestBase):
@seed_and_log
def setUp(self) -> None:
def setUp(self, backend: str = "nccl") -> None:
super().setUp()

num_features = 4
@@ -76,12 +76,11 @@ def setUp(self) -> None:
for feature in table.feature_names
]
}
self.backend = backend
if torch.cuda.is_available():
self.device = torch.device("cuda")
self.backend = "nccl"
else:
self.device = torch.device("cpu")
self.backend = "gloo"

def _test_sharding(
self,
@@ -122,8 +121,8 @@ def _test_sharding(

@skip_if_asan_class
class ModelParallelBase(ModelParallelTestShared):
def setUp(self) -> None:
super().setUp()
def setUp(self, backend: str = "nccl") -> None:
super().setUp(backend=backend)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
@@ -166,6 +165,11 @@ def test_sharding_rw(
],
variable_batch_size: bool,
) -> None:
if self.backend == "gloo":
self.skipTest(
"Gloo reduce_scatter_base fallback not supported with async_op=True"
)

sharding_type = ShardingType.ROW_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value
assume(
@@ -367,6 +371,11 @@ def test_sharding_variable_batch(
sharding_type: str,
global_constant_batch: bool,
) -> None:
if self.backend == "gloo":
# error is from FBGEMM, it says CPU even if we are on GPU.
self.skipTest(
"bounds_check_indices on CPU does not support variable length (batch size)"
)
self._test_sharding(
# pyre-ignore[6]
sharders=[
62 changes: 24 additions & 38 deletions torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -178,26 +178,27 @@ def _test_sharded_forward(


class ModelParallelSparseOnlyBase(unittest.TestCase):
def tearDown(self) -> None:
dist.destroy_process_group()

def test_sharding_ebc_as_top_level(self) -> None:
def setUp(self, backend: str = "nccl") -> None:
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["NCCL_SOCKET_IFNAME"] = "lo"

self.backend = backend
if torch.cuda.is_available():
curr_device = torch.device("cuda:0")
torch.cuda.set_device(curr_device)
backend = "nccl"
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device)
else:
curr_device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(backend=backend)
self.device = torch.device("cpu")

dist.init_process_group(backend=self.backend)

def tearDown(self) -> None:
dist.destroy_process_group()

def test_sharding_ebc_as_top_level(self) -> None:
embedding_dim = 128
num_embeddings = 256
ebc = EmbeddingBagCollection(
@@ -213,27 +214,11 @@ def test_sharding_ebc_as_top_level(self) -> None:
],
)

model = DistributedModelParallel(ebc, device=curr_device)
model = DistributedModelParallel(ebc, device=self.device)

self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))

def test_sharding_fused_ebc_as_top_level(self) -> None:
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["NCCL_SOCKET_IFNAME"] = "lo"

if torch.cuda.is_available():
curr_device = torch.device("cuda:0")
torch.cuda.set_device(curr_device)
backend = "nccl"
else:
curr_device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(backend=backend)

embedding_dim = 128
num_embeddings = 256
ebc = FusedEmbeddingBagCollection(
@@ -251,26 +236,26 @@
optimizer_kwargs={"lr": 0.02},
)

model = DistributedModelParallel(ebc, device=curr_device)
model = DistributedModelParallel(ebc, device=self.device)

self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))


class ModelParallelStateDictBase(unittest.TestCase):
def setUp(self) -> None:
def setUp(self, backend: str = "nccl") -> None:
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["NCCL_SOCKET_IFNAME"] = "lo"

self.backend = backend
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
backend = "nccl"
torch.cuda.set_device(self.device)
else:
self.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(backend=backend)

num_features = 4
@@ -377,27 +362,28 @@ def __init__(self, device: str, val: float) -> None:
def reset_parameters(self) -> None:
nn.init.constant_(self.p, self.val)

dist.destroy_process_group()
dist.init_process_group(backend="gloo")

# Check that already allocated parameters are left 'as is'.
cpu_model = MyModel(device="cpu", val=3.2)
unsharded_model = MyModel(device=self.device, val=3.2)
sharded_model = DistributedModelParallel(
cpu_model,
unsharded_model,
device=self.device,
)
sharded_param = next(sharded_model.parameters())
np.testing.assert_array_equal(
np.array([3.2, 3.2, 3.2], dtype=np.float32), sharded_param.detach().numpy()
np.array([3.2, 3.2, 3.2], dtype=np.float32),
sharded_param.detach().cpu().numpy(),
)

# Check that parameters over 'meta' device are allocated and initialized.
meta_model = MyModel(device="meta", val=7.5)
sharded_model = DistributedModelParallel(
meta_model,
device=self.device,
)
sharded_param = next(sharded_model.parameters())
np.testing.assert_array_equal(
np.array([7.5, 7.5, 7.5], dtype=np.float32), sharded_param.detach().numpy()
np.array([7.5, 7.5, 7.5], dtype=np.float32),
sharded_param.detach().cpu().numpy(),
)

def test_meta_device_dmp_state_dict(self) -> None:
11 changes: 8 additions & 3 deletions torchrec/distributed/tests/test_model_parallel_gloo.py
@@ -11,14 +11,19 @@
ModelParallelStateDictBase,
)

# CPU tests for Gloo.


class ModelParallelTestGloo(ModelParallelBase):
pass
def setUp(self, backend: str = "gloo") -> None:
super().setUp(backend=backend)


class ModelParallelStateDictTestGloo(ModelParallelStateDictBase):
pass
def setUp(self, backend: str = "gloo") -> None:
super().setUp(backend=backend)


class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase):
pass
def setUp(self, backend: str = "gloo") -> None:
super().setUp(backend=backend)
2 changes: 2 additions & 0 deletions torchrec/distributed/tests/test_model_parallel_nccl.py
@@ -6,7 +6,9 @@
# LICENSE file in the root directory of this source tree.

from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase
from torchrec.test_utils import skip_if_no_gpus


@skip_if_no_gpus()
class ModelParallelTestNccl(ModelParallelBase):
pass
@@ -9,11 +9,14 @@
ModelParallelSparseOnlyBase,
ModelParallelStateDictBase,
)
from torchrec.test_utils import skip_if_no_gpus


@skip_if_no_gpus()
class ModelParallelStateDictTestNccl(ModelParallelStateDictBase):
pass


@skip_if_no_gpus()
class ModelParallelSparseOnlyTestNccl(ModelParallelSparseOnlyBase):
pass
7 changes: 7 additions & 0 deletions torchrec/test_utils/__init__.py
@@ -95,6 +95,13 @@ def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> Optional[TReturn]:
return wrapper


def skip_if_no_gpus(cls: TReturn) -> Optional[TReturn]:
if torch.cuda.device_count() == 0:
cls.__unittest_skip__ = True
cls.__unittest_skip_why__ = "Skipping test run since we have no GPUs."
return cls


def skip_if_asan_class(cls: TReturn) -> Optional[TReturn]:
if is_asan_or_tsan():
cls.__unittest_skip__ = True
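
For reference, a self-contained sketch of the class-level skip pattern that the new skip_if_no_gpus helper follows (the same mechanism skip_if_asan_class already uses); the decorator and test class names below are hypothetical, not part of torchrec.

import unittest

import torch


def skip_if_no_gpus_sketch(cls: type) -> type:
    # Flip unittest's class-level skip flags when no CUDA device is visible,
    # so the whole test class is reported as skipped rather than failing.
    if torch.cuda.device_count() == 0:
        cls.__unittest_skip__ = True
        cls.__unittest_skip_why__ = "Skipping test run since we have no GPUs."
    return cls


@skip_if_no_gpus_sketch
class ExampleNcclOnlyTest(unittest.TestCase):
    def test_runs_only_with_gpus(self) -> None:
        self.assertGreater(torch.cuda.device_count(), 0)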