
Commit 336c75d

Author: Mark Lee

Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029)

* Allows specifying PartitionSpec to host_to_global_device_array.
* Generalizes to arbitrary uniform partitioning.
* Addresses comments and adds mixed shape test.
1 parent 0881412 commit 336c75d
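
The commit message describes passing a PartitionSpec (rather than only a DataPartitionType) to host_to_global_device_array. A minimal sketch of how that might look, assuming the function lives in axlearn.common.utils, that the keyword is named partition, and a mesh axis named "data"; these names are illustrative and not taken from the diff below:

# Hedged sketch: convert per-host arrays to a global array sharded by a PartitionSpec.
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec

from axlearn.common.utils import host_to_global_device_array  # assumed location

# Build a simple "data" x "model" mesh over the available devices (illustrative).
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "model"))
with mesh:
    # Each host contributes its local slice of the batch.
    local_batch = {"x": np.zeros((4, 8), dtype=np.float32)}
    # Assumed keyword `partition` now accepting a PartitionSpec (per the commit
    # message): shard the leading batch dimension over the "data" mesh axis.
    global_batch = host_to_global_device_array(local_batch, partition=PartitionSpec("data"))
    # For full data partitioning, the global leading dim aggregates all hosts' batches.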

File tree

5 files changed: +500 -124 lines changed


axlearn/common/gda_test.py

Lines changed: 21 additions & 11 deletions
@@ -6,9 +6,8 @@
 Some tests are intended to be run on TPU.
 """
 
-import itertools
-
 import jax
+import pytest
 from absl import logging
 from absl.testing import absltest, parameterized
 from jax import numpy as jnp
@@ -23,23 +22,17 @@
 )
 
 
+# TODO(markblee): Consolidate with utils_test.
 class GDATest(TestCase):
-    @parameterized.parameters(
-        itertools.product(
-            ((1, 1), (8, 1), (4, 2)),  # mesh_shape
-            (1, 16),  # per_host_batch_size
-            (DataPartitionType.FULL, DataPartitionType.REPLICATED),  # data_partition
-        )
-    )
-    def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
+    def _test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
         logging.info(
             "mesh_shape=%s per_host_batch_size=%s data_partition=%s",
             mesh_shape,
             per_host_batch_size,
             data_partition,
         )
         if not is_supported_mesh_shape(mesh_shape):
-            return
+            pytest.skip(reason=f"Unsupported {mesh_shape=}")
         devices = mesh_utils.create_device_mesh(mesh_shape)
         if data_partition == DataPartitionType.FULL:
             global_batch_size = per_host_batch_size * jax.process_count()
@@ -65,6 +58,23 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
         self.assertIsInstance(output["x"], Tensor)
         self.assertSequenceEqual(output["x"].shape, (global_batch_size, 8))
 
+    @parameterized.product(
+        mesh_shape=[(1, 1)],
+        per_host_batch_size=[1, 16],
+        data_partition=[DataPartitionType.FULL, DataPartitionType.REPLICATED],
+    )
+    def test_host_array_to_gda_single(self, **kwargs):
+        self._test_host_array_to_gda(**kwargs)
+
+    @parameterized.product(
+        mesh_shape=[(8, 1), (4, 2)],
+        per_host_batch_size=[1, 16],
+        data_partition=[DataPartitionType.FULL, DataPartitionType.REPLICATED],
+    )
+    @pytest.mark.for_8_devices
+    def test_host_array_to_gda_multiple(self, **kwargs):
+        self._test_host_array_to_gda(**kwargs)
+
 
 if __name__ == "__main__":
     absltest.main()
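
The refactor above splits one itertools.product-parameterized test into a shared helper plus single-device and multi-device wrappers, so the multi-device cases can be selected via a pytest marker and unsupported mesh shapes surface as skips instead of silently passing. A minimal standalone sketch of the same pattern (the class, helper, and device threshold are illustrative; only the marker name for_8_devices comes from the diff, and it would need to be registered in the repo's pytest configuration):

# Hedged sketch of the split-by-device-count test pattern used above.
# Intended to be collected by pytest, e.g.: pytest example_test.py -m for_8_devices
import jax
import pytest
from absl.testing import parameterized


class ExampleTest(parameterized.TestCase):
    def _check_mesh(self, mesh_shape):
        # Skip (rather than silently return) when the host has too few devices.
        if jax.device_count() < mesh_shape[0] * mesh_shape[1]:
            pytest.skip(reason=f"Unsupported {mesh_shape=}")
        self.assertGreaterEqual(jax.device_count(), mesh_shape[0] * mesh_shape[1])

    @parameterized.product(mesh_shape=[(1, 1)])
    def test_single_device(self, mesh_shape):
        self._check_mesh(mesh_shape)

    @parameterized.product(mesh_shape=[(8, 1), (4, 2)])
    @pytest.mark.for_8_devices  # custom marker, assumed registered in pytest config
    def test_multi_device(self, mesh_shape):
        self._check_mesh(mesh_shape)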
