Optimize permuting values for VBE with keyed_jagged_index_select (#1682)
Summary:
Pull Request resolved: #1682

In the VBE case, because the batch size varies per feature, we permute values by length per key. For inputs with high pooling factors, this shrinks the number of permute indices and greatly increases the amount of data that has to be copied for each permute index.
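
To make "permute values by length per key" concrete, here is a minimal CPU-only sketch (the function name and toy tensors below are illustrative only, not part of this diff): each entry of `recat` selects one key, and that key's entire slice of `values` is copied, so there are few permute indices but each one moves a large contiguous chunk.

```python
import torch


def permute_values_by_length_per_key(
    values: torch.Tensor,          # flat values of all keys, concatenated
    length_per_key: torch.Tensor,  # number of values owned by each key
    recat: torch.Tensor,           # desired key order (a permutation over keys)
) -> torch.Tensor:
    # offsets[k] marks where key k's slice starts inside `values`
    offsets = torch.cat(
        [length_per_key.new_zeros(1), length_per_key.cumsum(0)]
    ).tolist()
    # One copy per permute index; each copy spans an entire key's values.
    return torch.cat([values[offsets[k] : offsets[k + 1]] for k in recat.tolist()])


# 3 keys with lengths [4, 1, 3], reordered to [key2, key0, key1]
values = torch.arange(8)
length_per_key = torch.tensor([4, 1, 3])
recat = torch.tensor([2, 0, 1])
print(permute_values_by_length_per_key(values, length_per_key, recat))
# tensor([5, 6, 7, 0, 1, 2, 3, 4])
```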

The implementation of `permute_1D_sparse_data` processes 32 indices of the recat tensor per thread block (each index is handled by 32 threads), so with this few permute indices we end up using a single thread block to process all of them.

We further improved this implementation by splitting the work into 4 thread blocks of 256 threads each. This allows more threads to process each permute index, but SM utilization is still low.

`keyed_jagged_index_select_dim1` parallelizes the work for a single permute index across multiple thread blocks. Note that it is only supported on CUDA.
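
As a rough usage sketch of the new helper added in the diff below (assuming torchrec and fbgemm_gpu are installed; the tensors are toy values for illustration): on CPU the helper falls back to `permute_1D_sparse_data`, and on CUDA it takes the `keyed_jagged_index_select_dim1` path.

```python
import torch

from torchrec.sparse.jagged_tensor import _permute_variable_stride_values

values = torch.arange(8, dtype=torch.float32)  # 3 keys with lengths 4, 1, 3
length_per_key = torch.tensor([4, 1, 3])
recat = torch.tensor([2, 0, 1])  # new key order

permuted_values, permuted_weights = _permute_variable_stride_values(
    values,
    length_per_key,
    recat,
    None,  # no per-value weights
)
print(permuted_values)   # expected: tensor([5., 6., 7., 0., 1., 2., 3., 4.])
print(permuted_weights)  # None, since no weights were provided
```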

Reviewed By: sryap, AlbertDachiChen

Differential Revision: D53432094

fbshipit-source-id: 56395687c06f90b72cff47978aadbc6b532ff5ef
joshuadeng authored and facebook-github-bot committed Feb 6, 2024
1 parent 7edcb2d commit f934ec9
Showing 1 changed file with 48 additions and 12 deletions.
60 changes: 48 additions & 12 deletions torchrec/sparse/jagged_tensor.py
@@ -191,6 +191,48 @@ def _arange(*args, **kwargs) -> torch.Tensor:
    return torch.arange(*args, **kwargs)


def _permute_variable_stride_values(
    values: torch.Tensor,
    length_per_key: torch.Tensor,
    recat: torch.Tensor,
    weights: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    For variable stride tensors we permute across length per key, which reduces the
    number of permute indices and lengthens each sequence.
    `keyed_jagged_index_select_dim1` more efficiently parallelizes work for each permute
    index and sequence across multiple thread blocks.
    NOTE:
        `keyed_jagged_index_select_dim1` is only supported for CUDA.
    """
    if values.device.type == "cuda":
        output = torch.ops.fbgemm.keyed_jagged_index_select_dim1(
            values,
            length_per_key,
            _to_offsets(length_per_key),
            recat,
            len(length_per_key),
            weights,
            # TODO: add selected_lengths_sum once landed to prevent D2H sync
        )
        permuted_values = output[0]
        permuted_weights = None if weights is None else output[2]
    else:
        (
            _,
            permuted_values,
            permuted_weights,
        ) = torch.ops.fbgemm.permute_1D_sparse_data(
            recat,
            length_per_key,
            values,
            weights,
            None,
        )
    return permuted_values, permuted_weights


class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
    pass

@@ -1638,16 +1680,11 @@ def permute(
                 None,
                 None,
             )
-            (
-                _,
-                permuted_values,
-                permuted_weights,
-            ) = torch.ops.fbgemm.permute_1D_sparse_data(
-                indices_tensor,
-                length_per_key_tensor,
+            permuted_values, permuted_weights = _permute_variable_stride_values(
                 self.values(),
+                length_per_key_tensor,
+                indices_tensor,
                 self.weights_or_none(),
-                None,
             )
         else:
             (
@@ -1925,12 +1962,11 @@ def dist_init(
                 None,
                 None,
             )
-            (_, values, weights,) = torch.ops.fbgemm.permute_1D_sparse_data(
-                recat,
-                length_per_key,
+            values, weights = _permute_variable_stride_values(
                 values,
+                length_per_key,
+                recat,
                 weights,
-                None,
             )
             if not stride_per_key_per_rank:
                 stride_per_key_per_rank = [[0]] * len(keys)
