Permute values for VBE with keyed_jagged_index_select #1682

Closed
wants to merge 1 commit into from


joshuadeng
Contributor

Summary:
In the case of VBE, the batch size varies per feature, so we permute values using the length per key. For inputs with a high pooling factor, this reduces the number of permute indices and greatly increases the amount of data that must be copied per permute index.
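
To illustrate what permuting values by length per key means, here is a minimal pure-Python sketch. It is not the torchrec or FBGEMM implementation; `permute_values_by_key` and its argument layout are hypothetical names chosen for this example. Each permute index moves one whole per-key segment, so a few keys with long segments means a few very large copies.

```python
from itertools import accumulate


def permute_values_by_key(values, length_per_key, recat):
    """Reorder per-key segments of `values` according to `recat`.

    Hypothetical helper for illustration only.
    length_per_key[k] is the total number of values that belong to key k.
    recat[i] is the source key whose segment lands in output position i.
    """
    # Prefix sum with a leading 0 gives the start offset of each key's segment.
    offsets = [0] + list(accumulate(length_per_key))
    out = []
    for src_key in recat:
        start, end = offsets[src_key], offsets[src_key + 1]
        # One contiguous copy per permute index.
        out.extend(values[start:end])
    return out


values = [10, 11, 12, 20, 30, 31]
length_per_key = [3, 1, 2]  # key 0 owns 3 values, key 1 owns 1, key 2 owns 2
recat = [2, 0, 1]           # output order: key 2, then key 0, then key 1
permuted = permute_values_by_key(values, length_per_key, recat)
print(permuted)
```

With only three permute indices, the whole permutation is three segment copies, regardless of how long each segment is.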

The implementation of permute_1d_sparse_data processes 32 indices of the recat tensor per thread block, with each index processed by 32 threads. With so few permute indices, a single thread block ends up processing all of them.
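
The arithmetic behind that can be sketched as follows. This only restates the numbers given above (32 indices per block, 32 threads per index); the helper names, the 26 permute indices, and the 1,000,000-value segment are made-up inputs for illustration, not measurements or the actual kernel launch code.

```python
import math

INDICES_PER_BLOCK = 32  # taken from the description above
THREADS_PER_INDEX = 32  # taken from the description above


def launch_shape(num_permute_indices):
    """Return (thread blocks, threads per block) for the layout described above."""
    blocks = math.ceil(num_permute_indices / INDICES_PER_BLOCK)
    threads_per_block = INDICES_PER_BLOCK * THREADS_PER_INDEX
    return blocks, threads_per_block


def elements_per_thread(segment_length):
    """Upper bound on elements each thread copies for one permute index."""
    return math.ceil(segment_length / THREADS_PER_INDEX)


# Hypothetical workload: 26 permute indices, each moving 1,000,000 values.
blocks, threads = launch_shape(26)
per_thread = elements_per_thread(1_000_000)
print(blocks, threads, per_thread)
```

Under these assumptions the whole permute runs in one thread block, and each thread has to copy tens of thousands of elements on its own.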

We further improved this implementation by splitting the block into 4 thread blocks of 256 threads. This allows more threads to process each permute index; however, SM utilization is still low.

keyed_jagged_index_select_dim1 parallelizes the work for a single permute index across multiple thread blocks. Note that it only works on CUDA.
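
The general idea of spreading one permute index across several thread blocks can be sketched in plain Python as below. This is a conceptual simulation only and is not how keyed_jagged_index_select_dim1 is implemented; `split_segments_into_chunks` and `chunk_size` are hypothetical names, and each returned work item stands in for one thread block.

```python
def split_segments_into_chunks(segment_lengths, chunk_size):
    """Split every segment into fixed-size chunks.

    Hypothetical helper for illustration only. Returns a list of
    (segment_id, start, end) work items; each item represents the slice
    of a segment that one simulated thread block would copy.
    """
    work = []
    for seg_id, length in enumerate(segment_lengths):
        for start in range(0, length, chunk_size):
            work.append((seg_id, start, min(start + chunk_size, length)))
    return work


# Two segments of length 5 and 2, split into chunks of at most 2 elements.
work_items = split_segments_into_chunks([5, 2], chunk_size=2)
print(work_items)
```

The number of work items now scales with the amount of data rather than with the number of permute indices, which is what lets a long segment occupy many blocks at once.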

Differential Revision: D53432094

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 5, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D53432094

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Feb 5, 2024
…orch#1682)

Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
2 participants