Add v6e special meshes #952

Merged (4 commits) on Jan 27, 2025. Changes from all commits are shown below.
axlearn/common/utils.py (72 additions, 0 deletions)

@@ -5,6 +5,10 @@
# google/jax:
# Copyright 2018 Google LLC.
# Licensed under the Apache License, Version 2.0 (the "License").
#
# AI-Hypercomputer/maxtext:
# Copyright 2024 The MaxText Authors.
# Licensed under the Apache License, Version 2.0 (the "License").

"""Common utilities."""

@@ -1258,10 +1262,78 @@ def register_per_param_settings(
    return settings


def _reshape_mesh_to_rings(a: np.ndarray, *, shape: tuple[int, int]) -> np.ndarray:
    """Reshapes a device mesh into rings for the 64x4 or 32x8 mesh shape.

    Adapted from maxtext, with some code simplifications. Reference:
    https://github.com/AI-Hypercomputer/maxtext/blob/7f0dcef34f4857476d19b4ca9ceada654246c0b0/MaxText/max_utils.py#L474.

    64x4 and 32x8 are non-native mesh shapes on v6e and v5e and require careful arrangement of
    devices to achieve good performance.
    """
    b = []
    if shape == (64, 4):
        for i in range(8):
            b.append([])
            for j in range(8):
                a_i = i * 2
                a_j = j * 2
                # Forms a ring of size 4 from a 2x2 block of the physical mesh.
                b[i].append([a[a_i, a_j], a[a_i, a_j + 1], a[a_i + 1, a_j + 1], a[a_i + 1, a_j]])
    elif shape == (32, 8):
        for i in range(8):
            b.append([])
            for j in range(4):
                a_i = i * 2
                a_j = j * 4
                # Forms a ring of size 8 from a 2x4 block of the physical mesh.
                b[i].append(
                    [
                        a[a_i, a_j],
                        a[a_i, a_j + 1],
                        a[a_i, a_j + 2],
                        a[a_i, a_j + 3],
                        a[a_i + 1, a_j + 3],
                        a[a_i + 1, a_j + 2],
                        a[a_i + 1, a_j + 1],
                        a[a_i + 1, a_j],
                    ]
                )
    else:
        raise ValueError(f"The target mesh shape {shape} is not implemented.")
    return np.reshape(np.array(b), shape)
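
# Illustration (not part of this diff): a minimal sketch of the ring layout on a
# toy 16x16 grid of device ids, importing the private helper above for
# demonstration only.
import numpy as np

from axlearn.common.utils import _reshape_mesh_to_rings

grid = np.arange(256).reshape(16, 16)  # Stand-in for the 16x16 physical device mesh.
mesh = _reshape_mesh_to_rings(grid, shape=(64, 4))
assert mesh.shape == (64, 4)
# The first ring walks the top-left 2x2 block: (0, 0) -> (0, 1) -> (1, 1) -> (1, 0).
assert mesh[0].tolist() == [0, 1, 17, 16]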


def _maybe_get_special_mesh(
    mesh_shape: MeshShape, *, devices: np.ndarray
) -> Optional[tuple[int, int]]:
    """Returns the special mesh shape to use, or None if none is applicable."""
    if int(np.prod(mesh_shape)) != 256:
        return None
    if getattr(devices[0], "device_kind", None) not in [
        "TPU v5e",
        "TPU v6e",
        "TPU v6 lite",
        "TPU v5 lite",
    ]:
        return None

    filtered_mesh = tuple(filter(lambda x: x != 1, mesh_shape))
    target_shapes = [(64, 4), (32, 8)]
    return None if filtered_mesh not in target_shapes else filtered_mesh
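
# Illustration (not part of this diff): singleton axes are dropped before matching,
# so a logical mesh of (1, 64, 1, 4) still maps onto the special (64, 4) layout.
example_shape = (1, 64, 1, 4)
assert tuple(x for x in example_shape if x != 1) == (64, 4)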


def build_standard_mesh(mesh_shape: MeshShape, *, devices: np.ndarray) -> np.ndarray:
    logging.info("Building device mesh.")
    mesh_shape = infer_mesh_shape(mesh_shape, num_devices=devices.size)
    try:
        if (shape := _maybe_get_special_mesh(mesh_shape, devices=devices)) is not None:
            # If one of the special mesh shapes is applicable, use it.
            mesh = mesh_utils.create_device_mesh([16, 16], devices=devices)
            mesh = _reshape_mesh_to_rings(mesh, shape=shape)
            mesh = mesh.reshape(mesh_shape)
            logging.log_first_n(logging.INFO, "Using custom mesh: %s", 1, str(mesh))
            return mesh
        return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
    except NotImplementedError as e:
        logging.warning(
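
Illustration (not part of this diff): a minimal usage sketch of the path above, assuming the process sees a single slice of 256 TPU v6e devices; build_standard_mesh is the function modified in this diff.

import jax
import numpy as np

from axlearn.common.utils import build_standard_mesh

devices = np.asarray(jax.devices())  # Assumed: 256 v6e devices.
mesh = build_standard_mesh((64, 4), devices=devices)  # Takes the ring-based branch.
assert mesh.shape == (64, 4)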
axlearn/common/utils_test.py (31 additions, 2 deletions)

@@ -9,6 +9,7 @@
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from typing import Any, NamedTuple, Optional, Union
from unittest import mock

# pylint: disable=no-self-use
import jax
@@ -1553,11 +1554,32 @@ def test_create_device_mesh_multi_slice_tpuv4(
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, -1, 16), dcn_mesh_shape=(-1, 1, 1)),
"expected": (2, 16, 16),
},
# Test a case when a special optimized mesh can be used.
{
"logical_mesh": HybridMeshShape(
ici_mesh_shape=(1, 64, 1, 4), dcn_mesh_shape=(-1, 1, 1, 1)
),
"expected": (2, 64, 1, 4),
"is_custom": True,
},
# Test a case when a special optimized mesh can be used.
{
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, 64, 4), dcn_mesh_shape=(-1, 1, 1)),
"expected": (2, 64, 4),
"is_custom": True,
},
# Test a case when a special optimized mesh can be used.
{
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, 32, 8), dcn_mesh_shape=(-1, 1, 1)),
"expected": (2, 32, 8),
"is_custom": True,
},
)
    def test_create_device_mesh_multi_slice_tpuv5e(
        self,
        logical_mesh: Union[MeshShape, HybridMeshShape],
        expected: Optional[Union[MeshShape, Exception]] = None,
        is_custom: bool = False,
    ):
        slice_physical_mesh = (16, 16, 1)
        num_slices = 2
@@ -1570,7 +1592,7 @@ def test_create_device_mesh_multi_slice_tpuv5e(
        devices = [
            DummyMultiSliceTpuDevice(
                platform="tpu",
-                device_kind="TPU v5litepod",
+                device_kind="TPU v5 lite",
                process_index=(len(coords) * slice_index + ix) // 4,
                coords=coord,
                slice_index=slice_index,
@@ -1582,8 +1604,15 @@
            with self.assertRaisesRegex(type(expected), str(expected)):
                create_device_mesh(mesh_shape=logical_mesh, devices=devices)
        else:
+            # pylint: disable-next=protected-access
+            custom_mesh_fn = mock.Mock(wraps=utils._reshape_mesh_to_rings)
+            with mock.patch.object(utils, "_reshape_mesh_to_rings", custom_mesh_fn):
+                device_mesh = create_device_mesh(mesh_shape=logical_mesh, devices=devices)
+            if is_custom:
+                self.assertEqual(custom_mesh_fn.call_count, num_slices)
+            else:
+                custom_mesh_fn.assert_not_called()
            # Check that the constructed mesh has the expected shape.
-            device_mesh = create_device_mesh(mesh_shape=logical_mesh, devices=devices)
            self.assertEqual(expected or logical_mesh, device_mesh.shape)

        # Check that the sub_mesh along the first non-singleton mesh axis only contains devices
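
Illustration (not part of this diff): the test spies on the custom mesh path with mock.Mock(wraps=...), which records calls while still delegating to the real function. A standalone sketch of the same pattern:

from unittest import mock

def double(x):
    return 2 * x

spy = mock.Mock(wraps=double)
assert spy(3) == 6  # Delegates to the wrapped function.
spy.assert_called_once_with(3)  # And records the call.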