Skip to content

Commit 9bd58a6

Browse files
committed
register op
1 parent f4b119e commit 9bd58a6

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1376,7 +1376,7 @@ def _maybe_convert_scalar_types_to_dtypes(
13761376
class Work(_Work):
13771377
def __init__(self) -> None:
13781378
super().__init__()
1379-
self.event = torch.cuda.Event()
1379+
self.event = torch.xpu.Event()
13801380
self.event.record()
13811381

13821382
def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
@@ -1421,7 +1421,7 @@ def _low_contention_all_gather_meta(
14211421
group_size = c10d._get_group_size_by_name(group_name)
14221422
return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
14231423

1424-
1424+
@torch.library.impl(lib, "_low_contention_all_gather", "XPU")
14251425
@torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
14261426
def _low_contention_all_gather(
14271427
tensor: torch.Tensor,
@@ -1454,7 +1454,7 @@ def _low_contention_all_gather(
14541454
output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:])
14551455
chunks = output.chunk(world_size)
14561456

1457-
_get_backend_stream().wait_stream(torch.cuda.current_stream())
1457+
_get_backend_stream().wait_stream(torch.xpu.current_stream())
14581458
with _get_backend_stream():
14591459
if not input_is_symm_mem:
14601460
local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype)
@@ -1492,7 +1492,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
14921492
a2a_res = torch.empty_like(tensor)
14931493
chunks = a2a_res.chunk(world_size)
14941494

1495-
_get_backend_stream().wait_stream(torch.cuda.current_stream())
1495+
_get_backend_stream().wait_stream(torch.xpu.current_stream())
14961496
with _get_backend_stream():
14971497
# pull + offline reduction
14981498
symm_mem.barrier()
@@ -1529,7 +1529,7 @@ def _low_contention_reduce_scatter_with_workspace(
15291529
assert tensor.shape[0] % world_size == 0
15301530
chunks = tensor.chunk(world_size)
15311531

1532-
_get_backend_stream().wait_stream(torch.cuda.current_stream())
1532+
_get_backend_stream().wait_stream(torch.xpu.current_stream())
15331533
with _get_backend_stream():
15341534
# push + offline reduction
15351535
workspace.barrier()
@@ -1552,7 +1552,7 @@ def _low_contention_reduce_scatter_with_workspace(
15521552
torch._C._distributed_c10d._register_work(ret, Work())
15531553
return ret
15541554

1555-
1555+
@torch.library.impl(lib, "_low_contention_reduce_scatter", "XPU")
15561556
@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA")
15571557
def _low_contention_reduce_scatter(
15581558
tensor: torch.Tensor,

0 commit comments

Comments
 (0)