[1/N] Introduce init_device_mesh() (pytorch#107254)
This PR introduces init_device_mesh() as an API to standardize the UX of device_mesh initialization.

The functionality of slicing out a submesh from a given mesh will come in later PRs.
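For quick context, a minimal usage sketch of the new API (assumptions: an 8-rank job, e.g. launched with torchrun, and a CUDA build; the "dp"/"tp" dimension names are illustrative):

# Minimal sketch, assuming 8 ranks and CUDA devices; names are illustrative.
from torch.distributed._tensor.device_mesh import init_device_mesh

# Lay out the 8 ranks as a 2x4 mesh and name the two dimensions.
two_d_mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))

# Equivalent to DeviceMesh("cuda", torch.arange(8).view(2, 4), mesh_dim_names=("dp", "tp")).
print(two_d_mesh.mesh_dim_names)  # ("dp", "tp")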
Pull Request resolved: pytorch#107254
Approved by: https://github.com/wanchaol
wz337 authored and pytorchmergebot committed Aug 21, 2023
1 parent 5ddb8ef commit f5d1df3
Showing 2 changed files with 76 additions and 1 deletion.
25 changes: 24 additions & 1 deletion test/distributed/_tensor/test_device_mesh.py
@@ -9,7 +9,7 @@
mesh_broadcast,
mesh_scatter,
)
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed._tensor.placement_types import Shard

from torch.distributed.distributed_c10d import (
@@ -157,6 +157,29 @@ def test_device_mesh_hash(self):
self.assertNotEqual(hash(mesh2), hash(mesh3))


class InitDeviceMeshTest(DTensorTestBase):
@property
def world_size(self):
return 8

@with_comms
def test_init_device_mesh(self):
mesh_shape = (2, 4)
ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape))

# test init_device_mesh with mesh_dim_names
mesh_dim_names = ("DP", "TP")
two_d_mesh = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
self.assertEqual(two_d_mesh, ref_mesh)
self.assertEqual(two_d_mesh.mesh_dim_names, mesh_dim_names)

# test init_device_mesh without mesh_dim_names
two_d_mesh = init_device_mesh(self.device_type, mesh_shape)
self.assertEqual(two_d_mesh, ref_mesh)


class DeviceMeshCollectiveTest(DTensorTestBase):
@property
def world_size(self):
52 changes: 52 additions & 0 deletions torch/distributed/_tensor/device_mesh.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

import torch
@@ -102,6 +103,7 @@ def __init__(
device_type: str,
mesh: Union[torch.Tensor, "ArrayLike"],
*,
mesh_dim_names: Optional[Tuple[str, ...]] = None,
_init_process_groups: bool = True,
_validate_mesh: bool = True,
) -> None:
@@ -111,6 +113,7 @@ def __init__(
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, dtype=torch.int)
)
self.mesh_dim_names = mesh_dim_names
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
# process (we need to know if the current global rank is in the mesh or not)
@@ -269,3 +272,52 @@ def get_coordinate(self) -> Optional[List[int]]:
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
return self._coordinate_on_dim if self._coordinate_on_dim else None


def init_device_mesh(
device_type: str,
mesh_shape: Tuple[int, ...],
*,
mesh_dim_names: Optional[Tuple[str, ...]] = None,
) -> DeviceMesh:
"""
Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
This creates a DeviceMesh with an n-dimensional array layout, where n is len(mesh_shape) and the
i-th dimension has size mesh_shape[i]. If mesh_dim_names is provided, each dimension is labeled
mesh_dim_names[i].
Args:
device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
that describes the layout of devices.
Kwargs:
mesh_dim_names (Optional[Tuple[str]]): A tuple of mesh dim names to be assigned to each dimension
of the multi-dimensional array that describes the layout of devices. Its length must match the length
of `mesh_shape`.
Returns:
A :class:`DeviceMesh` object
.. note: If no process group is found, init_device_mesh will initialize distributed process group/groups
behind the scenes, which are required for distributed communications.
Example:
>>> # xdoctest: +SKIP
>>> from torch.distributed._tensor.device_mesh import init_device_mesh
>>>
>>> two_d_mesh = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
"""
if mesh_dim_names is not None and len(mesh_shape) != len(mesh_dim_names):
raise RuntimeError(
f"Please provide a mesh_dim_name to each mesh_dim! Found {len(mesh_dim_names)} instead of {len(mesh_shape)}."
)

mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
device_mesh = DeviceMesh(
device_type=device_type,
mesh=mesh,
mesh_dim_names=mesh_dim_names,
)

return device_mesh
