-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathsharding_test.py
345 lines (292 loc) · 10.8 KB
/
sharding_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import os
# Environment configuration for the tests below. These assignments run before
# `jax` is imported further down, so they are already in effect when JAX
# initialises its backends:
#   - CUDA_VISIBLE_DEVICES='' hides any GPUs,
#   - JAX_PLATFORMS='cpu' restricts JAX to the CPU backend,
#   - XLA_FLAGS requests 8 virtual host devices, which is exactly the number
#     of devices a (2, 4) mesh (as built by MeshShardingHelperTest) needs.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS
from jax.sharding import Mesh, NamedSharding
from absl.testing import absltest, parameterized
# Code under test.
from scalax.sharding import (
    FSDPShardingRule, TreePathShardingRule, PolicyShardingRule,
    MeshShardingHelper, with_sharding_annotation, with_sharding_constraint,
    get_global_mesh_helper, get_global_mesh
)
class FSDPShardingRuleTest(parameterized.TestCase):
    """Checks the partition specs FSDPShardingRule assigns to leaves of
    various sizes and shapes."""

    @parameterized.parameters(
        (4, 1024),
        (8, 2048),
        (2, 512),
    )
    def test_sharding_rule(self, fsdp_axis_size, min_fsdp_size):
        threshold = min_fsdp_size
        # leaf name -> (shape handed to jnp.ones, spec the rule should produce).
        # The expectations encode that leaves smaller than the threshold stay
        # replicated (PS()), and that large leaves are split along one axis.
        # For 'weird_matrix' the split moves to the second axis, presumably
        # because threshold + 1 is not divisible by fsdp_axis_size. When
        # neither axis qualifies ('really_weird_matrix'), the leaf stays
        # replicated.
        cases = {
            'scaler': ([], PS()),
            'small_vector': (threshold - fsdp_axis_size, PS()),
            'large_vector': (threshold * 2, PS('fsdp')),
            'small_matrix': ((16, 16), PS()),
            'large_matrix': ((threshold, threshold), PS('fsdp', None)),
            'weird_matrix': ((threshold + 1, threshold), PS(None, 'fsdp')),
            'really_weird_matrix': ((threshold + 1, threshold + 1), PS()),
        }
        pytree = {name: jnp.ones(shape) for name, (shape, _) in cases.items()}
        rule = FSDPShardingRule(
            fsdp_axis_name='fsdp',
            fsdp_axis_size=fsdp_axis_size,
            min_fsdp_size=min_fsdp_size,
        )
        self.assertEqual(
            rule.apply(pytree),
            {name: spec for name, (_, spec) in cases.items()},
        )
class TreePathShardingRuleTest(parameterized.TestCase):
    """Tests for TreePathShardingRule, whose patterns address pytree leaves
    by path (nested keys joined with '/', e.g. 'b/c')."""

    def test_tree_path_sharding_rule(self):
        """Matched leaves get their pattern's spec; unmatched leaves get PS()."""
        pytree = {
            'a': jnp.ones((16, 16)),
            'b': {
                'c': jnp.ones((16, 16)),
                'd': jnp.ones((16, 16)),
            },
            'e': jnp.ones([]),
        }
        sharding_rule = TreePathShardingRule(
            ('a', PS('x', 'y')),
            ('b/c', PS('x')),
            ('b/d', PS('y')),
        )
        matched_partition_specs = sharding_rule.apply(pytree)
        expected_partition_specs = {
            'a': PS('x', 'y'),
            'b': {
                'c': PS('x'),
                'd': PS('y'),
            },
            # 'e' matches no pattern; in non-strict mode it is replicated.
            'e': PS(),
        }
        self.assertEqual(matched_partition_specs, expected_partition_specs)

    def test_tree_path_sharding_rule_strict(self):
        """With strict=True, a leaf that matches no rule pattern raises
        ValueError."""
        pytree = {
            'a': jnp.ones((16, 16)),
            'b': {
                'c': jnp.ones((16, 16)),
                'd': jnp.ones((16, 16)),
            },
        }
        # Neither pattern covers the leaf at 'b/d'.
        sharding_rule = TreePathShardingRule(
            ('a', PS('x', 'y')),
            ('b/c', PS('x')),
            strict=True,
        )
        with self.assertRaises(ValueError):
            sharding_rule.apply(pytree)
class PolicyShardingRuleTest(parameterized.TestCase):
    """Tests for PolicyShardingRule, which delegates spec selection to a
    user-supplied callable taking (path, value)."""

    def test_policy_sharding_rule(self):
        def policy_fn(path, value):
            # Conditions are checked in order; the first one that holds wins.
            if path == 'a':
                return PS('x')
            if value.dtype == jnp.int32:
                return PS('y')
            if not value.shape:
                # Scalar leaf (empty shape tuple).
                return PS('z')
            return PS()

        pytree = {
            'a': jnp.ones((16, 16)),
            'b': {
                'c': jnp.ones((16, 16), dtype=jnp.float32),
                'd': jnp.ones((16, 16), dtype=jnp.int32),
            },
            'e': jnp.ones([]),
        }
        expected = {
            'a': PS('x'),
            'b': {
                'c': PS(),
                'd': PS('y'),
            },
            'e': PS('z'),
        }
        self.assertEqual(PolicyShardingRule(policy_fn).apply(pytree), expected)
class MeshShardingHelperTest(parameterized.TestCase):
    """Tests for MeshShardingHelper on a 2 x 4 mesh with axes ('x', 'y').

    A 2 x 4 mesh needs 8 devices, which matches the host device count this
    module requests through XLA_FLAGS.
    """

    @staticmethod
    def _make_mesh():
        """Build the 2 x 4 ('x', 'y') mesh helper shared by every test."""
        return MeshShardingHelper(
            axis_dims=(2, 4),
            axis_names=('x', 'y'),
        )

    @parameterized.parameters(32, 64, 192)
    def test_sjit_static_args(self, dim):
        """sjit honours static_argnums; an unmarked callable arg is rejected."""
        mesh = self._make_mesh()

        def static_arg_fn():
            return 3.0

        # `y` (a Python callable) is marked static, so in_shardings only
        # describes the two remaining arguments, `x` and `z`.
        @partial(
            mesh.sjit,
            in_shardings=(PS(), PS()),
            static_argnums=(1,),
        )
        def static_fn(x, y, z):
            return jnp.zeros((dim, dim)) + x + z + static_arg_fn()

        output = static_fn(1.0, static_arg_fn, 2.0)
        self.assertEqual(output.shape, (dim, dim))
        np.testing.assert_array_equal(np.asarray(output), 6.0)

        # Without static_argnums the callable is passed as a regular argument.
        # A function object is not a valid JAX type, so this call is expected
        # to raise TypeError.
        @partial(
            mesh.sjit,
            in_shardings=(PS(), PS(), PS()),
        )
        def non_static_fn(x, y, z):
            return jnp.zeros((dim, dim)) + x + z + static_arg_fn()

        with self.assertRaises(TypeError):
            non_static_fn(1.0, static_arg_fn, 2.0)

    @parameterized.parameters(32, 64, 192)
    def test_sjit_out_shardings(self, dim):
        """out_shardings may mix a sharding rule and a plain PartitionSpec."""
        mesh = self._make_mesh()
        sharding_rule = TreePathShardingRule(
            ('a', PS('x', 'y')),
            ('b', PS('y', 'x')),
        )

        @partial(
            mesh.sjit,
            out_shardings=(sharding_rule, PS(('x', 'y')))
        )
        def sharded_fn(x):
            output_rule = {
                'a': jnp.zeros((dim, dim)),
                'b': jnp.zeros((dim, dim)),
            }
            output_ps = jnp.zeros((dim, dim))
            return output_rule, output_ps

        output_rule, output_ps = sharded_fn(1.0)
        self.assertEqual(
            output_rule['a'].sharding,
            NamedSharding(mesh.mesh, PS('x', 'y'))
        )
        self.assertEqual(
            output_rule['b'].sharding,
            NamedSharding(mesh.mesh, PS('y', 'x'))
        )
        self.assertEqual(
            output_ps.sharding,
            NamedSharding(mesh.mesh, PS(('x', 'y')))
        )

    def test_sjit_get_global_mesh(self):
        """Inside an sjit function this helper is the global mesh helper."""
        mesh = self._make_mesh()

        @mesh.sjit
        def sharded_fn(x):
            # Python-level assertions. If sjit wraps jax.jit (not verified
            # here), they execute while the function is traced.
            self.assertIs(get_global_mesh_helper(), mesh)
            self.assertIs(get_global_mesh(), mesh.mesh)
            return x

        sharded_fn(1.0)

    @parameterized.parameters(32, 64, 192)
    def test_with_sharding_constraint(self, dim):
        """with_sharding_constraint accepts either a rule or a PartitionSpec
        and matches jax.lax.with_sharding_constraint."""
        mesh = self._make_mesh()
        sharding_rule = PolicyShardingRule(lambda path, value: PS('x', 'y'))

        @mesh.sjit
        def rule_constrained_fn(x):
            return with_sharding_constraint(x, sharding_rule)

        @mesh.sjit
        def spec_constrained_fn(x):
            return with_sharding_constraint(x, PS('x', 'y'))

        # Plain-JAX baseline that both sjit variants must match.
        @jax.jit
        def reference_fn(x):
            return jax.lax.with_sharding_constraint(
                x, NamedSharding(mesh.mesh, PS('x', 'y'))
            )

        inputs = jnp.ones((dim, dim))
        sjit_rule_output = rule_constrained_fn(inputs)
        sjit_spec_output = spec_constrained_fn(inputs)
        reference_output = reference_fn(inputs)
        self.assertEqual(sjit_rule_output.sharding, reference_output.sharding)
        self.assertEqual(sjit_spec_output.sharding, reference_output.sharding)

    @parameterized.parameters(32, 64, 192)
    def test_with_sharding_annotation(self, dim):
        """Named annotations resolve through sjit's annotation_shardings,
        whether mapped to a rule or to a PartitionSpec."""
        mesh = self._make_mesh()
        sharding_rule = PolicyShardingRule(lambda path, value: PS('x', 'y'))

        @partial(
            mesh.sjit,
            annotation_shardings={'activation': sharding_rule}
        )
        def rule_constrained_fn(x):
            return with_sharding_annotation(x, 'activation')

        @partial(
            mesh.sjit,
            annotation_shardings={'activation': PS('x', 'y')}
        )
        def spec_constrained_fn(x):
            return with_sharding_annotation(x, 'activation')

        # Plain-JAX baseline that both sjit variants must match.
        @jax.jit
        def reference_fn(x):
            return jax.lax.with_sharding_constraint(
                x,
                NamedSharding(mesh.mesh, PS('x', 'y'))
            )

        inputs = jnp.ones((dim, dim))
        sjit_rule_output = rule_constrained_fn(inputs)
        sjit_spec_output = spec_constrained_fn(inputs)
        reference_output = reference_fn(inputs)
        self.assertEqual(sjit_rule_output.sharding, reference_output.sharding)
        self.assertEqual(sjit_spec_output.sharding, reference_output.sharding)

    @parameterized.parameters(32, 64, 192)
    def test_local_data_to_global_array(self, dim):
        """Local host data becomes a global array sharded over the requested
        mesh-axis subset along the batch axis."""
        mesh = self._make_mesh()
        data = np.ones((dim, dim))
        # (extra kwargs, spec entries preceding the batch-axis entry).
        # With no batch_axis argument the test expects dimension 0 to carry
        # the sharding. With batch_axis=1, dimension 0 is left unsharded.
        batch_axis_cases = (
            ({}, ()),
            ({'batch_axis': 1}, (None,)),
        )
        # (mesh_axis_subset, spec entry expected for the batch dimension).
        subset_cases = (
            (('x',), 'x'),
            (('y',), 'y'),
            (('x', 'y'), ('x', 'y')),
        )
        for extra_kwargs, leading_spec in batch_axis_cases:
            for mesh_axis_subset, batch_spec in subset_cases:
                with self.subTest(
                    mesh_axis_subset=mesh_axis_subset, **extra_kwargs
                ):
                    global_array = mesh.local_data_to_global_array(
                        data, mesh_axis_subset=mesh_axis_subset, **extra_kwargs
                    )
                    self.assertEqual(global_array.shape, data.shape)
                    self.assertEqual(
                        global_array.sharding,
                        NamedSharding(
                            mesh.mesh, PS(*leading_spec, batch_spec)
                        ),
                    )
if __name__ == '__main__':
    # Hand control to absl's test runner, which discovers and runs the
    # TestCase classes defined in this module.
    absltest.main()