Optimize permuting values for VBE with keyed_jagged_index_select (#1682)
Summary:

In the VBE case, the batch size varies per feature, so we permute values by length per key. For inputs with a high pooling factor, this reduces the number of permute indices but greatly increases the amount of data copied per permute index.
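To make the trade-off concrete, here is a minimal pure-Python sketch of permuting values by length per key. It is a reference illustration only, not the FBGEMM kernel; the helper name `permute_segments` and the sample data are hypothetical.

```python
from itertools import accumulate
from typing import List, Sequence


def permute_segments(
    values: Sequence[int],
    length_per_key: Sequence[int],
    recat: Sequence[int],
) -> List[int]:
    """Reorder the per-key segments of `values` in the order given by `recat`."""
    # Complete prefix sum: offsets[i] is where the segment of key i starts.
    offsets = [0, *accumulate(length_per_key)]
    permuted: List[int] = []
    for src in recat:
        # One permute index copies one whole segment, however long it is.
        permuted.extend(values[offsets[src] : offsets[src + 1]])
    return permuted


# Hypothetical data: three keys holding 3, 1 and 2 values respectively.
values = [10, 11, 12, 20, 30, 31]
length_per_key = [3, 1, 2]
recat = [2, 0, 1]
print(permute_segments(values, length_per_key, recat))
# [30, 31, 10, 11, 12, 20]
```

There are only as many permute indices as keys, so when segments are long each index is responsible for a large copy.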

The implementation of `permute_1d_sparse_data` processes 32 indices of the recat tensor per thread block, with 32 threads per index. As a result, a single thread block ends up processing all of the indices.
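A rough sketch of the launch arithmetic implied by those numbers. The ceiling-division grid size is an assumption made for illustration, not taken from the FBGEMM source, and the helper names are hypothetical.

```python
import math

INDICES_PER_BLOCK = 32  # from the summary: 32 recat indices per thread block
THREADS_PER_INDEX = 32  # from the summary: 32 threads per index


def thread_blocks_needed(num_permute_indices: int) -> int:
    """Thread blocks launched, assuming a ceiling-division grid (an assumption)."""
    return math.ceil(num_permute_indices / INDICES_PER_BLOCK)


def threads_per_block() -> int:
    """Threads in one block under the layout described in the summary."""
    return INDICES_PER_BLOCK * THREADS_PER_INDEX


for n in (8, 32, 33):
    print(n, "indices ->", thread_blocks_needed(n), "block(s)")
print("threads per block:", threads_per_block())
```

Under this assumption, any recat tensor with 32 or fewer indices is handled by one block, regardless of how much data each index has to copy.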

We then improved this implementation by splitting the block into 4 thread blocks of 256 threads each. This lets more threads work on each permute index, but SM utilization is still low.

`keyed_jagged_index_select_dim1` parallelizes the work for a single permute index across multiple thread blocks.
Note that it only works on CUDA.
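The idea of spreading one permute index over several workers can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the FBGEMM kernel: `build_work_items`, `run_work_items`, and the fixed `chunk_size` are hypothetical, and each work item stands in for what one thread block might copy.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple

# (source start, destination start, element count); hypothetical work unit.
WorkItem = Tuple[int, int, int]


def build_work_items(
    length_per_key: Sequence[int],
    recat: Sequence[int],
    chunk_size: int,
) -> List[WorkItem]:
    """Split every permuted segment into chunks of at most `chunk_size` elements."""
    src_offsets = [0, *accumulate(length_per_key)]
    items: List[WorkItem] = []
    dst = 0
    for src in recat:
        start, end = src_offsets[src], src_offsets[src + 1]
        for chunk_start in range(start, end, chunk_size):
            count = min(chunk_size, end - chunk_start)
            items.append((chunk_start, dst, count))
            dst += count
    return items


def run_work_items(values: Sequence[int], items: Sequence[WorkItem]) -> List[int]:
    """Execute the copies; items write disjoint ranges, so order is irrelevant."""
    out = [0] * sum(count for _, _, count in items)
    for src_start, dst_start, count in reversed(items):
        out[dst_start : dst_start + count] = values[src_start : src_start + count]
    return out


values = list(range(10))
items = build_work_items(length_per_key=[6, 4], recat=[1, 0], chunk_size=4)
print(items)   # [(6, 0, 4), (0, 4, 4), (4, 8, 2)]
print(run_work_items(values, items))  # [6, 7, 8, 9, 0, 1, 2, 3, 4, 5]
```

Because the number of work items grows with segment length rather than with the number of keys, long segments no longer serialize on a single worker.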

Differential Revision: D53432094
joshuadeng authored and facebook-github-bot committed Feb 5, 2024
1 parent f8f6f61 commit b96e8d0
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions torchrec/sparse/jagged_tensor.py
@@ -191,6 +191,38 @@ def _arange(*args, **kwargs) -> torch.Tensor:
     return torch.arange(*args, **kwargs)


+def _permute_variable_stride_values(
+    values: torch.Tensor,
+    length_per_key: torch.Tensor,
+    recat: torch.Tensor,
+    weights: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if values.device.type == "cuda":
+        output = torch.ops.fbgemm.keyed_jagged_index_select_dim1(
+            values,
+            length_per_key,
+            _to_offsets(length_per_key),
+            recat,
+            len(length_per_key),
+            weights,
+        )
+        permuted_values = output[0]
+        permuted_weights = None if weights is None else output[2]
+    else:
+        (
+            _,
+            permuted_values,
+            permuted_weights,
+        ) = torch.ops.fbgemm.permute_1D_sparse_data(
+            recat,
+            length_per_key,
+            values,
+            weights,
+            None,
+        )
+    return permuted_values, permuted_weights
+
+
 class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
     pass

@@ -1638,16 +1670,11 @@ def permute(
                 None,
                 None,
             )
-            (
-                _,
-                permuted_values,
-                permuted_weights,
-            ) = torch.ops.fbgemm.permute_1D_sparse_data(
-                indices_tensor,
-                length_per_key_tensor,
+            permuted_values, permuted_weights = _permute_variable_stride_values(
                 self.values(),
+                length_per_key_tensor,
+                indices_tensor,
                 self.weights_or_none(),
-                None,
             )
         else:
             (
@@ -1925,12 +1952,11 @@ def dist_init(
                 None,
                 None,
             )
-            (_, values, weights,) = torch.ops.fbgemm.permute_1D_sparse_data(
-                recat,
-                length_per_key,
+            values, weights = _permute_variable_stride_values(
                 values,
+                length_per_key,
+                recat,
                 weights,
-                None,
             )
             if not stride_per_key_per_rank:
                 stride_per_key_per_rank = [[0]] * len(keys)
