Commit
Optimize permuting values for VBE with keyed_jagged_index_select (#1682)
Summary:
Pull Request resolved: #1682

In the case of VBE, the batch size varies per feature, so we permute values using the length per key. For inputs with a high pooling factor, this reduces the number of permute indices but greatly increases the amount of data that must be copied for each permute index.

The implementation of `permute_1d_sparse_data` processes 32 indices (of the recat tensor) in one thread block, with each index handled by 32 threads. With so few permute indices, one thread block ends up processing all of them. We previously improved this implementation by splitting the block into 4 thread blocks of 256 threads. This lets more threads work on each permute index, but SM utilization is still low.

`keyed_jagged_index_select_dim1` parallelizes the work for a single permute index across multiple thread blocks. Note that it only works on CUDA.

Reviewed By: sryap, AlbertDachiChen

Differential Revision: D53432094

fbshipit-source-id: 56395687c06f90b72cff47978aadbc6b532ff5ef
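
For readers unfamiliar with the operation being optimized, the following is a minimal pure-Python sketch of what permuting jagged values by segment length means. It is an illustration only, not the FBGEMM kernel or its API: the function name `permute_1d_reference` and the sample data are hypothetical, and the real ops work on tensors on the GPU.

```python
from itertools import accumulate


def permute_1d_reference(permute, lengths, values):
    """Reorder variable-length segments of `values` according to `permute`.

    Hypothetical reference helper for illustration; not part of FBGEMM.
    `lengths[i]` is the number of values in segment i.
    """
    # Exclusive prefix sum: segment i spans values[offsets[i]:offsets[i + 1]].
    offsets = [0, *accumulate(lengths)]
    out_lengths = [lengths[i] for i in permute]
    out_values = []
    for i in permute:
        # One copy per permute index; the copy size grows with the segment length.
        out_values.extend(values[offsets[i]:offsets[i + 1]])
    return out_lengths, out_values


lengths = [3, 1, 2]                  # three segments (e.g. one per key)
values = [10, 11, 12, 20, 30, 31]    # concatenated segment contents
permute = [2, 0, 1]                  # new segment order

out_lengths, out_values = permute_1d_reference(permute, lengths, values)
print(out_lengths)  # [2, 3, 1]
print(out_values)   # [30, 31, 10, 11, 12, 20]
```

The loop runs once per permute index, and each iteration copies a whole segment. When there are few indices but each segment is long, most of the work sits inside a handful of large copies, which is the case where spreading a single index across multiple thread blocks helps.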