6
6
import unittest
7
7
from unittest .mock import patch
8
8
import sys
9
- import os
10
9
11
10
import torch
12
11
from torch import nn
27
26
from torch_xla ._internal import tpu
28
27
29
28
30
- def should_convert_to_shardy ():
31
- return os .environ .get ("CONVERT_SHLO_TO_SHARDY" ,
32
- "" ).lower () in ("1" , "true" , "yes" )
33
-
34
-
35
29
class BasicXlaShardingTest (test_xla_sharding_base .XlaShardingTest ):
36
30
37
31
@classmethod
38
32
def setUpClass (cls ):
39
33
super ().setUpClass ()
34
+ cls .convert_to_shardy = xu .check_env_flag ("CONVERT_SHLO_TO_SHARDY" )
40
35
41
36
def test_xla_sharded_tensor (self ):
42
37
partition_spec = (0 , 1 )
@@ -244,7 +239,7 @@ def test_custom_tile_assignment(self):
244
239
if self .n_devices > 1 :
245
240
annotation = '{devices=[1,%d]%s}' % (self .n_devices , ',' .join (
246
241
[str (i ) for i in reversed (range (self .n_devices ))]))
247
- if should_convert_to_shardy () :
242
+ if self . convert_to_shardy :
248
243
annotation = '{devices=[1,%d]<=[%d]}' % (self .n_devices , self .n_devices )
249
244
self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
250
245
@@ -260,7 +255,7 @@ def test_mark_sharding_2d(self):
260
255
if self .n_devices > 1 :
261
256
annotation = '{devices=[1,%d]%s}' % (self .n_devices , ',' .join (
262
257
[str (i ) for i in range (self .n_devices )]))
263
- if should_convert_to_shardy () :
258
+ if self . convert_to_shardy :
264
259
annotation = '{devices=[1,%d]<=[%d]}' % (self .n_devices , self .n_devices )
265
260
self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt1 ))
266
261
@@ -281,7 +276,7 @@ def test_mark_sharding_4d(self):
281
276
annotation = '{devices=[1,1,%d,%d]%s}' % (
282
277
z_dim , self .n_devices // z_dim , ',' .join (
283
278
[str (i ) for i in range (self .n_devices )]))
284
- if should_convert_to_shardy () :
279
+ if self . convert_to_shardy :
285
280
annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim , self .n_devices //
286
281
z_dim , self .n_devices )
287
282
self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
@@ -418,7 +413,7 @@ def test_tupled_partition_spec(self):
418
413
xs .mark_sharding (t , mesh , ((0 , 1 ),))
419
414
annotation = "{devices=[%d]%s}" % (self .n_devices , ',' .join (
420
415
str (x ) for x in range (self .n_devices )))
421
- if should_convert_to_shardy () :
416
+ if self . convert_to_shardy :
422
417
annotation = "{devices=[%d]<=[%d]}" % (self .n_devices , self .n_devices )
423
418
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
424
419
@@ -432,7 +427,7 @@ def test_named_partial_tupled_partition_spec(self):
432
427
xs .mark_sharding (t , mesh , (('r' , 'b' ), None ))
433
428
annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
434
429
self .n_devices // 2 , ',' .join (str (x ) for x in range (self .n_devices )))
435
- if should_convert_to_shardy () :
430
+ if self . convert_to_shardy :
436
431
annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
437
432
self .n_devices // 2 , self .n_devices )
438
433
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -442,7 +437,7 @@ def test_named_partial_tupled_partition_spec(self):
442
437
xs .mark_sharding (u , mesh , (None , ('b' , 'm' )))
443
438
annotation = "{devices=[1,%d]%s}" % (self .n_devices , ',' .join (
444
439
str (x ) for x in range (self .n_devices )))
445
- if should_convert_to_shardy () :
440
+ if self . convert_to_shardy :
446
441
annotation = "{devices=[1,%d]<=[%d]}" % (self .n_devices , self .n_devices )
447
442
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (u ), annotation )
448
443
@@ -452,7 +447,7 @@ def test_named_partial_tupled_partition_spec(self):
452
447
device_order = mesh .get_logical_mesh ().transpose ((0 , 2 , 1 )).flatten ()
453
448
annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
454
449
self .n_devices // 2 , ',' .join (str (x ) for x in device_order ))
455
- if should_convert_to_shardy () :
450
+ if self . convert_to_shardy :
456
451
annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
457
452
self .n_devices // 2 , self .n_devices // 2 )
458
453
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
@@ -463,7 +458,7 @@ def test_named_partial_tupled_partition_spec(self):
463
458
device_order = mesh .get_logical_mesh ().transpose ((2 , 1 , 0 )).flatten ()
464
459
annotation = "{devices=[1,%d]%s}" % (self .n_devices , ',' .join (
465
460
str (x ) for x in device_order ))
466
- if should_convert_to_shardy () :
461
+ if self . convert_to_shardy :
467
462
annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self .n_devices ,
468
463
self .n_devices // 2 )
469
464
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
@@ -478,7 +473,7 @@ def test_multiple_tuples_in_spec(self):
478
473
xs .mark_sharding (t , mesh , (('a' , 'b' ), ('c' , 'd' )))
479
474
annotation = "{devices=[2,%d]%s}" % (self .n_devices // 2 , ',' .join (
480
475
str (x ) for x in range (self .n_devices )))
481
- if should_convert_to_shardy () :
476
+ if self . convert_to_shardy :
482
477
annotation = "{devices=[2,%d]<=[%d]}" % (self .n_devices // 2 ,
483
478
self .n_devices )
484
479
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -491,7 +486,7 @@ def test_3d_tensor_2d_mesh(self):
491
486
xs .mark_sharding (t , mesh , (None , 0 , 1 ))
492
487
annotation = '{devices=[1,2,%d]%s}' % (self .n_devices // 2 , ',' .join (
493
488
str (x ) for x in range (self .n_devices )))
494
- if should_convert_to_shardy () :
489
+ if self . convert_to_shardy :
495
490
annotation = '{devices=[1,2,%d]<=[%d]}' % (self .n_devices // 2 ,
496
491
self .n_devices )
497
492
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -1013,8 +1008,7 @@ def test_op_sharding_cache(self):
1013
1008
1014
1009
t = torch .randn (1 , self .n_devices ).to ('xla' )
1015
1010
xs .mark_sharding (t , mesh , (0 , 1 ))
1016
- counter_name = "CreateIotaOpSharding" if should_convert_to_shardy (
1017
- ) else "CreateOpSharding"
1011
+ counter_name = "CreateIotaOpSharding" if self .convert_to_shardy else "CreateOpSharding"
1018
1012
self .assertIn (counter_name , met .counter_names ())
1019
1013
self .assertEqual (met .counter_value (counter_name ), 1 )
1020
1014
@@ -1435,7 +1429,7 @@ def test_data_loader_with_sharding(self):
1435
1429
data , _ = iter (train_device_loader ).__next__ ()
1436
1430
self .assertEqual (data .size (), torch .Size ([8 , 3 , 64 , 64 ]))
1437
1431
annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]{ ',' .join ([str (i ) for i in range (mesh .size ())])} }}"
1438
- if should_convert_to_shardy () :
1432
+ if self . convert_to_shardy :
1439
1433
annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]<=[{ mesh .size ()} ]}}"
1440
1434
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (data ), annotation )
1441
1435
@@ -1458,7 +1452,7 @@ def test_data_loader_with_non_batch_size(self):
1458
1452
data , _ = iter (train_device_loader ).__next__ ()
1459
1453
self .assertEqual (data .size (), torch .Size ([mesh .size () - 1 , 3 , 64 , 64 ]))
1460
1454
annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]{ ',' .join ([str (i ) for i in range (mesh .size ())])} }}"
1461
- if should_convert_to_shardy () :
1455
+ if self . convert_to_shardy :
1462
1456
annotation = f"{{devices=[{ mesh .size ()} ,1,1,1]<=[{ mesh .size ()} ]}}"
1463
1457
self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (data ), annotation )
1464
1458
0 commit comments