Modify torchrec code to use reduce_scatter_v and all_gather_v #580

Closed
wants to merge 1 commit into from
170 changes: 170 additions & 0 deletions torchrec/distributed/comm_ops.py
@@ -205,6 +205,27 @@ class AllGatherBaseInfo(object):
codecs: Optional[QuantizedCommCodecs]


@dataclass
class ReduceScatterVInfo(object):
"""
The data class that collects the attributes used by the `reduce_scatter_v_pooled`
operation.

Attributes:
input_sizes (List[torch.Size]): the sizes of the per-rank input chunks. These are
needed in the backward pass to produce gradients of the correct shapes.
input_splits (List[int]): the split sizes of the input tensor along dim 0, one per rank.
equal_splits (bool): whether all splits are equal, in which case the plain
(non-vector) reduce-scatter can be used.
total_input_size (List[int]): the size of the full input tensor.
codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
"""

input_sizes: List[torch.Size]
input_splits: List[int]
equal_splits: bool
total_input_size: List[int]
codecs: Optional[QuantizedCommCodecs]


@dataclass
class All2AllDenseInfo(object):
"""
@@ -519,6 +540,58 @@ def all_gather_base_pooled(
return myreq


def reduce_scatter_v_pooled(
input: Tensor,
input_splits: List[int],
group: Optional[dist.ProcessGroup] = None,
codecs: Optional[QuantizedCommCodecs] = None,
) -> Awaitable[Tensor]:
"""
Performs a reduce-scatter-v operation on a pooled embeddings tensor split into a
world-size number of chunks of potentially unequal sizes along dim 0. The result of
the reduce operation is scattered to the processes in the group according to
`input_splits`.

Args:
input (Tensor): tensor to reduce and scatter, containing one chunk per rank along dim 0.
input_splits (List[int]): sizes of the chunks along dim 0, one per rank.
group (Optional[dist.ProcessGroup]): the process group to work on. If None, the
default process group will be used.
codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.

Returns:
Awaitable[Tensor]: async work handle (Awaitable); call `wait()` on it later to get the
resulting tensor.

.. warning::
`reduce_scatter_v_pooled` is experimental and subject to change.
"""

if group is None:
group = dist.distributed_c10d._get_default_group()

if dist.get_world_size(group) <= 1:
return NoWait(input)

myreq = Request(group)
input_size = list(input.size())
input_sizes = [
torch.Size(
[ip_split if d == 0 else input_size[d] for d in range(len(input_size))]
)
for ip_split in input_splits
]
equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits)

rsvi = ReduceScatterVInfo(
input_sizes=input_sizes,
input_splits=input_splits,
equal_splits=equal_splits,
total_input_size=input_size,
codecs=codecs,
)
# pyre-fixme[16]: `ReduceScatterV_Req` has no attribute `apply`.
ReduceScatterV_Req.apply(group, myreq, rsvi, input)
return myreq
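
A minimal usage sketch, assuming a NCCL process group has already been initialized (e.g. under torchrun); the split sizes, embedding dimension, and device handling below are illustrative and not part of the diff.

import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import reduce_scatter_v_pooled

# Assumes dist.init_process_group("nccl") has already run (e.g. via torchrun).
world_size = dist.get_world_size()
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())

# Illustrative uneven splits: rank r contributes (and receives) r + 1 rows.
input_splits = [r + 1 for r in range(world_size)]
embedding_dim = 8

# Every rank holds the same layout: sum(input_splits) rows of pooled embeddings.
local_input = torch.randn(sum(input_splits), embedding_dim, device=device)

# Kick off the async reduce-scatter-v; rank r receives the reduced rows of its own split.
awaitable = reduce_scatter_v_pooled(local_input, input_splits)
output = awaitable.wait()  # shape: [input_splits[rank], embedding_dim]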


# TODO: improve performance of _recat_pooled_embedding_grad_out, see T87591139
def _recat_pooled_embedding_grad_out(
grad_output: Tensor, num_features_per_rank: List[int]
@@ -1181,3 +1254,100 @@ def backward(ctx, grad_outputs: Tensor) -> Tuple[None, None, Tensor]:
myreq.req = req
myreq.tensor = grad_input
return (None, None, grad_outputs)


class ReduceScatterV_Req(Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
# pyre-fixme[2]: Parameter must be annotated.
ctx,
pg: dist.ProcessGroup,
myreq: Request[Tensor],
rsi: ReduceScatterVInfo,
input: Tensor,
) -> Tensor:
my_rank = dist.get_rank(pg)
output = input.new_empty(rsi.input_sizes[my_rank])

# Use dist._reduce_scatter_base when a vector reduce-scatter is not needed;
# otherwise use dist.reduce_scatter, which supports vector (uneven) reduce-scatter internally.
if rsi.equal_splits:
with record_function("## reduce_scatter_base ##"):
req = dist._reduce_scatter_base(output, input, group=pg, async_op=True)
else:
with record_function("## reduce_scatter_v ##"):
req = dist.reduce_scatter(
output,
list(torch.split(input, rsi.input_splits)),
group=pg,
async_op=True,
)

myreq.req = req
myreq.tensor = output
myreq.wait_function = ReduceScatterV_Wait
myreq.rsi = rsi
ctx.myreq = myreq
ctx.pg = pg

return output

@staticmethod
# pyre-fixme[2]: Parameter must be annotated.
def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
myreq = ctx.myreq
myreq.req.wait()
myreq.req = None
grad_input = myreq.tensor
# Make it equivalent to running on a single rank.
if GRADIENT_DIVISION:
grad_input.div_(dist.get_world_size(ctx.pg))
myreq.tensor = None
return (None, None, None, grad_input)


class ReduceScatterV_Wait(Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
# pyre-fixme[2]: Parameter must be annotated.
ctx,
pg: dist.ProcessGroup,
myreq: Request[Tensor],
output: Tensor,
) -> Tensor:
myreq.req.wait()
myreq.req = None
myreq.tensor = None
ctx.myreq = myreq
ctx.pg = pg
return output

@staticmethod
# pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently.
# pyre-fixme[2]: Parameter must be annotated.
def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
myreq = ctx.myreq
rsi = myreq.rsi
grad_input = grad_output.new_empty(rsi.total_input_size)

if rsi.equal_splits:
with record_function("## reduce_scatter_base_bw (all_gather) ##"):
req = dist._all_gather_base(
grad_input,
grad_output.contiguous(),
group=ctx.pg,
async_op=True,
)
else:
with record_function("## reduce_scatter_v_bw (all_gather_v) ##"):
req = dist.all_gather(
list(torch.split(grad_input, rsi.input_splits)),
grad_output.contiguous(),
group=ctx.pg,
async_op=True,
)
myreq.req = req
myreq.tensor = grad_input
return (None, None, grad_output)
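
As a single-process illustration of the math that the ReduceScatterV_Req / ReduceScatterV_Wait pair above implements (not the library's code path): the forward sums the per-rank inputs elementwise and gives rank r its own split, and the input gradient is the concatenation (all-gather-v) of all ranks' output gradients, optionally divided by the world size when GRADIENT_DIVISION is set.

import torch

# Simulate 3 "ranks", each holding the same uneven layout of rows.
input_splits = [1, 3, 2]
dim = 4
per_rank_inputs = [torch.randn(sum(input_splits), dim) for _ in input_splits]

# Forward: elementwise sum across ranks, then split so that rank r keeps chunk r.
reduced = torch.stack(per_rank_inputs).sum(dim=0)
per_rank_outputs = list(torch.split(reduced, input_splits))

# Backward: each rank's input gradient is the concatenation (all-gather-v)
# of every rank's output gradient, matching the total input layout.
per_rank_output_grads = [torch.ones_like(out) for out in per_rank_outputs]
grad_input = torch.cat(per_rank_output_grads, dim=0)
assert grad_input.shape == per_rank_inputs[0].shape
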
60 changes: 59 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -18,7 +18,7 @@
alltoall_pooled,
alltoall_sequence,
reduce_scatter_base_pooled,
reduce_scatter_pooled,
reduce_scatter_v_pooled,
)
from torchrec.distributed.types import Awaitable, NoWait, QuantizedCommCodecs
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -874,6 +874,64 @@ def forward(self, local_emb: torch.Tensor) -> PooledEmbeddingsAwaitable:
return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable)


class PooledEmbeddingsReduceScatterV(nn.Module):
"""
The module class that wraps the reduce-scatter-v communication primitive for pooled
embedding communication in row-wise and table-row-wise (twrw) sharding.

For pooled embeddings, we have a local model-parallel output tensor with a layout of
[num_buckets x batch_size, dimension]. We need to sum over the num_buckets dimension
across batches. We split the tensor along the first dimension into unequal chunks
(tensor slices of different buckets) according to input_splits, reduce them into the
output tensor, and scatter the results to the corresponding ranks.

The class returns the async `Awaitable` handle for the pooled embeddings tensor.
Reduce-scatter-v is only available with the NCCL backend.

Args:
pg (dist.ProcessGroup): The process group that the reduce-scatter communication
happens within.
codecs (Optional[QuantizedCommCodecs]): Quantization codec

Example::

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2 * 2, 2)
input_splits = [1,3]
m = PooledEmbeddingsReduceScatterV(pg)
output = m(input, input_splits)
tensor = output.wait()
"""

def __init__(
self,
pg: dist.ProcessGroup,
codecs: Optional[QuantizedCommCodecs] = None,
) -> None:
super().__init__()
self._pg = pg
self._codecs = codecs

def forward(
self, local_embs: torch.Tensor, input_splits: List[int]
) -> PooledEmbeddingsAwaitable:
"""
Performs a reduce-scatter-v operation on the pooled embeddings tensor.

Args:
local_embs (torch.Tensor): tensor of shape [num_buckets x batch_size, dimension].
input_splits (List[int]): list of split sizes for local_embs along dim 0.

Returns:
PooledEmbeddingsAwaitable: awaitable of a pooled embeddings tensor of shape [batch_size, dimension].
"""
tensor_awaitable = reduce_scatter_v_pooled(
local_embs, input_splits, self._pg, codecs=self._codecs
)
return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable)
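
A brief local sketch of the shapes in the docstring example above (simulating the two ranks in one process, for illustration only): an input of shape [4, 2] with input_splits=[1, 3] is summed across ranks, after which rank 0 holds a [1, 2] chunk and rank 1 a [3, 2] chunk.

import torch

# Simulate the two "ranks" of the docstring example locally.
input_splits = [1, 3]
rank0_input = torch.randn(4, 2)
rank1_input = torch.randn(4, 2)

# Reduce (sum) across ranks, then scatter chunk r to rank r.
reduced = rank0_input + rank1_input
rank0_out, rank1_out = torch.split(reduced, input_splits)
assert rank0_out.shape == (1, 2) and rank1_out.shape == (3, 2)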


class SequenceEmbeddingsAwaitable(Awaitable[torch.Tensor]):
"""
Awaitable for sequence embeddings after collective operation.
16 changes: 5 additions & 11 deletions torchrec/distributed/sharding/vb_rw_sharding.py
@@ -11,7 +11,7 @@

import torch
import torch.distributed as dist
from torchrec.distributed.dist_data import PooledEmbeddingsReduceScatter
from torchrec.distributed.dist_data import PooledEmbeddingsReduceScatterV
from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup
from torchrec.distributed.embedding_sharding import (
BaseEmbeddingLookup,
@@ -161,7 +161,6 @@ def __init__(self, awaitable: Awaitable[torch.Tensor], batch_size: int) -> None:

def _wait_impl(self) -> torch.Tensor:
embedding = self._awaitable.wait()
embedding = torch.narrow(embedding, 0, 0, self._batch_size)

return embedding

@@ -174,24 +173,19 @@ def __init__(
super().__init__()
self._workers: int = pg.size()
self._rank: int = pg.rank()
self._dist = PooledEmbeddingsReduceScatter(pg)
self._dist = PooledEmbeddingsReduceScatterV(pg)

def forward(
self,
local_embs: torch.Tensor,
sharding_ctx: VariableBatchShardingContext,
) -> Awaitable[torch.Tensor]:
batch_size_per_rank_tensor = sharding_ctx.batch_size_per_rank_tensor
batch_size_per_rank = sharding_ctx.batch_size_per_rank
max_length = max(batch_size_per_rank)
batch_size = batch_size_per_rank[self._rank]
packed_pooled_embs = torch.ops.fbgemm.pack_segments(
t_in=local_embs,
lengths=batch_size_per_rank_tensor,
max_length=max_length,
)

awaitable_tensor = self._dist(
packed_pooled_embs.view(self._workers * max_length, -1)
local_embs.view(sum(batch_size_per_rank), -1),
input_splits=batch_size_per_rank,
)
return VariableBatchRwEmbeddingDistAwaitable(awaitable_tensor, batch_size)
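
An illustrative sketch (not part of the diff) of the shape bookkeeping the new forward relies on: with variable batch sizes, the local pooled output has sum(batch_size_per_rank) rows and no padding, and reduce-scatter-v with input_splits=batch_size_per_rank hands each rank exactly its own batch, so the old pack_segments padding and the later torch.narrow are no longer needed. The batch sizes and dimension below are made up.

import torch

# Hypothetical per-rank batch sizes for a world size of 3.
batch_size_per_rank = [2, 5, 3]
embedding_dim = 4

# Local model-parallel output: one row group per destination rank, no padding.
local_embs = torch.randn(sum(batch_size_per_rank), embedding_dim)

# After reduce-scatter-v, rank r's chunk has batch_size_per_rank[r] rows;
# the old padded path gave every rank max(batch_size_per_rank) rows and
# then narrowed the result.
chunks = list(torch.split(local_embs, batch_size_per_rank))
assert [c.shape[0] for c in chunks] == batch_size_per_rank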

30 changes: 6 additions & 24 deletions torchrec/distributed/sharding/vb_twrw_sharding.py
@@ -15,7 +15,7 @@
import torch.distributed as dist
from torchrec.distributed.dist_data import (
PooledEmbeddingsAllToAll,
PooledEmbeddingsReduceScatter,
PooledEmbeddingsReduceScatterV,
)
from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup
from torchrec.distributed.embedding_sharding import (
@@ -248,7 +248,7 @@ def __init__(
self._intra_pg: dist.ProcessGroup = intra_pg
self._cross_pg: dist.ProcessGroup = cross_pg
self._device: Optional[torch.device] = device
self._intra_dist = PooledEmbeddingsReduceScatter(intra_pg)
self._intra_dist = PooledEmbeddingsReduceScatterV(intra_pg)
self._cross_dist = PooledEmbeddingsAllToAll(
cross_pg,
dim_sum_per_node,
@@ -270,32 +270,14 @@ def forward(
self._cross_pg.size(),
sharding_ctx.batch_size_per_rank,
)
# Pad each chunk to same size, prepare for ReduceScatter
# Skip padding when a host has no table assigned, in which case its dim is 0
max_length = max(batch_size_sum_by_cross_group)
if local_embs.shape[1] != 0:
# pyre-fixme[28]: Unexpected keyword argument `pin_memory`.
lengths = torch.tensor(
batch_size_sum_by_cross_group,
pin_memory=True,
).to(device=self._device, non_blocking=True)
local_embs = torch.ops.fbgemm.pack_segments(
t_in=local_embs,
max_length=max_length,
lengths=lengths,
)
# Perform ReduceScatter within one host

# Perform ReduceScatterV within one host
lengths = batch_size_sum_by_cross_group
rs_result = self._intra_dist(
local_embs.view(self._intra_pg.size() * max_length, -1)
local_embs.view(sum(lengths), -1), input_splits=lengths
).wait()

local_rank = self._rank % self._intra_pg.size()
rs_result = torch.narrow(
rs_result,
0,
0,
batch_size_sum_by_cross_group[local_rank],
)

return self._cross_dist(
rs_result,