Some tests are intended to be run on TPU.
"""

import jax
import pytest
from absl import logging
from absl.testing import absltest, parameterized
from jax import numpy as jnp
)


25+ # TODO(markblee): Consolidate with utils_test.
2626class GDATest (TestCase ):
27- @parameterized .parameters (
28- itertools .product (
29- ((1 , 1 ), (8 , 1 ), (4 , 2 )), # mesh_shape
30- (1 , 16 ), # per_host_batch_size
31- (DataPartitionType .FULL , DataPartitionType .REPLICATED ), # data_partition
32- )
33- )
34- def test_host_array_to_gda (self , mesh_shape , per_host_batch_size , data_partition ):
27+ def _test_host_array_to_gda (self , mesh_shape , per_host_batch_size , data_partition ):
3528 logging .info (
3629 "mesh_shape=%s per_host_batch_size=%s data_partition=%s" ,
3730 mesh_shape ,
3831 per_host_batch_size ,
3932 data_partition ,
4033 )
4134 if not is_supported_mesh_shape (mesh_shape ):
42- return
35+ pytest . skip ( reason = f"Unsupported { mesh_shape = } " )
4336 devices = mesh_utils .create_device_mesh (mesh_shape )
4437 if data_partition == DataPartitionType .FULL :
4538 global_batch_size = per_host_batch_size * jax .process_count ()
@@ -65,6 +58,23 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
6558 self .assertIsInstance (output ["x" ], Tensor )
6659 self .assertSequenceEqual (output ["x" ].shape , (global_batch_size , 8 ))
6760
61+ @parameterized .product (
62+ mesh_shape = [(1 , 1 )],
63+ per_host_batch_size = [1 , 16 ],
64+ data_partition = [DataPartitionType .FULL , DataPartitionType .REPLICATED ],
65+ )
66+ def test_host_array_to_gda_single (self , ** kwargs ):
67+ self ._test_host_array_to_gda (** kwargs )
68+
69+ @parameterized .product (
70+ mesh_shape = [(8 , 1 ), (4 , 2 )],
71+ per_host_batch_size = [1 , 16 ],
72+ data_partition = [DataPartitionType .FULL , DataPartitionType .REPLICATED ],
73+ )
74+ @pytest .mark .for_8_devices
75+ def test_host_array_to_gda_multiple (self , ** kwargs ):
76+ self ._test_host_array_to_gda (** kwargs )
77+
6878
# Script entry point: delegate to absltest so absl flags are parsed and the
# test cases in this module are discovered and run.
if __name__ == "__main__":
    absltest.main()