Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to compile collective for PT > 2.3 #6674

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions deepspeed/comm/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
DS_COMM_ALL_REDUCE_OFF = False
DS_COMM_REDUCE_OFF = False

def disable_compiler_collective(func):
nelyahu marked this conversation as resolved.
Show resolved Hide resolved
if required_torch_version(min_version=2.3):
return func
return compiler.disable(func)

def build_shm_op():
builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
Expand Down Expand Up @@ -114,7 +118,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

@classmethod
@compiler.disable
@disable_compiler_collective
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
Expand All @@ -123,7 +127,7 @@ def get_all_gather_function(self):
return None

@classmethod
@compiler.disable
@disable_compiler_collective
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
Expand All @@ -146,19 +150,19 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

@compiler.disable
@disable_compiler_collective
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op, group=None):
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

@compiler.disable
@disable_compiler_collective
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
Expand All @@ -169,15 +173,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -190,7 +194,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -199,7 +203,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -208,15 +212,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -234,7 +238,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
@disable_compiler_collective
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
Expand All @@ -258,7 +262,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
else:
reqs[-1].wait()

@compiler.disable
@disable_compiler_collective
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
Expand All @@ -272,7 +276,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
@disable_compiler_collective
def all_to_all_single(self,
output,
input,
Expand All @@ -287,49 +291,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

@compiler.disable
@disable_compiler_collective
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
Expand Down
Loading