Permute values for VBE with keyed_jagged_index_select #1682
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
In the case of VBE, due to variable batch per feature, we permute values via length per key. For high pooling factor inputs this shrinks the number of permutes and greatly increases the amount of data needed to be copied per each permute index.
The implementation of
permute_1d_sparse_data
processes 32 indices (of the recat tensor) in one thread block (each index is processed by 32 threads). So, we use 1 thread block to process all indices.We further improved this implementation further to separate the block into 4 thread blocks of 256 threads. This allows more threads to process each permute index, however the sm utilization is still low.
keyed_jagged_index_select_dim1
parallelizes work for a single permute index across multiple thread blocksnote that it only works on cuda.
Differential Revision: D53432094