@@ -1376,7 +1376,7 @@ def _maybe_convert_scalar_types_to_dtypes(
 class Work(_Work):
     def __init__(self) -> None:
         super().__init__()
-        self.event = torch.cuda.Event()
+        self.event = torch.xpu.Event()
         self.event.record()

     def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
@@ -1421,7 +1421,7 @@ def _low_contention_all_gather_meta(
     group_size = c10d._get_group_size_by_name(group_name)
     return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])

-
+@torch.library.impl(lib, "_low_contention_all_gather", "XPU")
 @torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
 def _low_contention_all_gather(
     tensor: torch.Tensor,
@@ -1454,7 +1454,7 @@ def _low_contention_all_gather(
     output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:])
     chunks = output.chunk(world_size)

-    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    _get_backend_stream().wait_stream(torch.xpu.current_stream())
     with _get_backend_stream():
         if not input_is_symm_mem:
             local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype)
@@ -1492,7 +1492,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
     a2a_res = torch.empty_like(tensor)
     chunks = a2a_res.chunk(world_size)

-    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    _get_backend_stream().wait_stream(torch.xpu.current_stream())
     with _get_backend_stream():
         # pull + offline reduction
         symm_mem.barrier()
@@ -1529,7 +1529,7 @@ def _low_contention_reduce_scatter_with_workspace(
     assert tensor.shape[0] % world_size == 0
     chunks = tensor.chunk(world_size)

-    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    _get_backend_stream().wait_stream(torch.xpu.current_stream())
     with _get_backend_stream():
         # push + offline reduction
         workspace.barrier()
@@ -1552,7 +1552,7 @@ def _low_contention_reduce_scatter_with_workspace(
         torch._C._distributed_c10d._register_work(ret, Work())
         return ret

-
+@torch.library.impl(lib, "_low_contention_reduce_scatter", "XPU")
 @torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA")
 def _low_contention_reduce_scatter(
     tensor: torch.Tensor,
0 commit comments