@@ -65,6 +65,7 @@ def dummy_init_pg() -> None:
 def _test_pg(
     pg: ProcessGroup,
     example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+    skip: list[str] = [],
 ) -> Dict[str, dist._Work]:
     """
     Helper function to test a set of collective operations on a given process group.
@@ -124,6 +125,8 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
     works: Dict[str, dist._Work] = {}
 
     for coll_str, args in collectives:
+        if coll_str in skip:
+            continue
         try:
             coll = getattr(pg, coll_str)
             work = coll(*args)
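A minimal sketch of the skip mechanism these two hunks add: the helper's collective loop consults a caller-provided `skip` list before dispatching each collective by name. The `run_collectives` helper and `collectives` structure below are illustrative stand-ins for `_test_pg`, not the real torchft test harness.

```python
from typing import Any, Dict, List, Tuple


def run_collectives(
    pg: Any,
    collectives: List[Tuple[str, Tuple[Any, ...]]],
    skip: List[str] = [],
) -> Dict[str, Any]:
    """Run each (name, args) collective on pg, skipping names listed in `skip`."""
    works: Dict[str, Any] = {}
    for coll_str, args in collectives:
        if coll_str in skip:
            # Collective is known to be broken on this backend/device; don't run it.
            continue
        coll = getattr(pg, coll_str)  # look up the collective method by name
        works[coll_str] = coll(*args)
    return works
```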
@@ -496,7 +499,12 @@ def run_reduce_scatter_tensor_coalesced_test(
 
 
 class ProcessGroupTest(TestCase):
-    def test_gloo_apis(self) -> None:
+    @parameterized.expand(["cpu", "cuda"])
+    def test_gloo_apis(self, device: str) -> None:
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA is not available")
+            return
+
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
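This hunk turns the single Gloo test into one test per device via `parameterized.expand`, skipping the CUDA variant when no GPU is present. A self-contained sketch of that pattern, using a stand-in `DeviceTest` class rather than the real `ProcessGroupTest`:

```python
import unittest

import torch
from parameterized import parameterized


class DeviceTest(unittest.TestCase):
    @parameterized.expand(["cpu", "cuda"])  # generates test_tensor_on_device_0_cpu / _1_cuda
    def test_tensor_on_device(self, device: str) -> None:
        if device == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
            return
        t = torch.tensor([2], device=device)
        self.assertEqual(t.device.type, device)
```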
@@ -507,11 +515,23 @@ def test_gloo_apis(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        _test_pg(pg)
+        _test_pg(
+            pg,
+            torch.tensor([2], device=device),
+            skip=(
+                # https://github.com/pytorch/pytorch/issues/152645
+                [
+                    "allreduce_coalesced",
+                    "allgather_into_tensor_coalesced",
+                ]
+                if device == "cuda"
+                else []
+            ),
+        )
 
-        m = nn.Linear(3, 4)
+        m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
-        m(torch.rand(2, 3))
+        m(torch.rand(2, 3, device=device))
 
     def test_gloo_timeout(self) -> None:
         store = TCPStore(
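The last hunk also moves the DDP smoke test onto the selected device: both the module parameters and the input tensor must live on `device` before the forward pass. A minimal sketch of just that device-aware forward pass; the DDP wrapping is omitted here because it requires an initialized process group.

```python
import torch
import torch.nn as nn


def smoke_test_linear(device: str = "cpu") -> torch.Tensor:
    m = nn.Linear(3, 4).to(device)       # move parameters to the target device
    x = torch.rand(2, 3, device=device)  # input must be on the same device as the module
    return m(x)                          # output shape: (2, 4)


out = smoke_test_linear("cpu")
assert out.shape == (2, 4)
```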