Skip to content

Commit

Permalink
Optimize _length_per_key_from_stride_per_key with segment sum csr when appropriate (pytorch#1699)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#1699

`segment_sum_csr` outperforms the torch split/cat ops under certain conditions.

However there is performance degradation when
  1. the number of segments is small
  2. there are many elements in each segment to sum

Reviewed By: AlbertDachiChen

Differential Revision: D53020207

fbshipit-source-id: 3068479af2068922e4d9e393908c2a90982eb43f
  • Loading branch information
joshuadeng authored and facebook-github-bot committed Feb 14, 2024
1 parent d23b6c7 commit 8761b37
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,12 +693,39 @@ def _maybe_compute_stride_kjt_scripted(
return torch.tensor([_maybe_compute_stride_kjt(keys, stride, lengths, offsets)])


def _use_segment_sum_csr(stride_per_key: List[int]) -> bool:
"""
`segment_sum_csr` performs poorly for small number of segments and many elements
in each segment to sum. This function uses an empirically calculated equation,
derived from fitting a quadratic regression to an interval of elements and elements
per segment that match performance between the kernel and PyTorch solution, to
determine the threshold of when to use `segment_sum_csr`.
"""
elements_per_segment = sum(stride_per_key) / len(stride_per_key)
segment_threshold = int(
1.39771
+ 0.0000312222 * elements_per_segment
+ 1.63949e-10 * elements_per_segment**2
)
return len(stride_per_key) >= segment_threshold


def _length_per_key_from_stride_per_key(
    lengths: torch.Tensor, stride_per_key: List[int]
) -> List[int]:
    """
    Sum `lengths` per key, where key i owns the next `stride_per_key[i]`
    entries of `lengths`, and return one total per key as a Python list.

    NOTE: the diff rendering had left the old unconditional split/cat `return`
    in front of the new branching code, making the kernel path unreachable;
    this is the intended merged implementation.
    """
    if _use_segment_sum_csr(stride_per_key):
        # Kernel path: build CSR offsets from the per-key strides (pinned and
        # moved to `lengths`'s device) and let fbgemm reduce each segment.
        stride_per_key_offsets = _to_offsets(
            _pin_and_move(
                torch.tensor(stride_per_key, dtype=torch.int32), lengths.device
            )
        )
        return torch.ops.fbgemm.segment_sum_csr(
            1, stride_per_key_offsets, lengths
        ).tolist()
    else:
        # Fallback: split `lengths` into per-key chunks, sum each, and
        # concatenate the scalar sums before converting to a list.
        return torch.cat(
            [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
        ).tolist()


def _maybe_compute_length_per_key(
Expand Down

0 comments on commit 8761b37

Please sign in to comment.