Optimize Gather CUDA Kernel #7351

Merged · 10 commits · Jan 25, 2022
122 changes: 72 additions & 50 deletions oneflow/user/kernels/gather_kernel_util.cu
@@ -16,47 +16,89 @@
limitations under the License.
*/
#include "oneflow/user/kernels/gather_kernel_util.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include <assert.h>

namespace oneflow {

namespace {

// Removed by this PR: per-element input offset computed with explicit div/mod arithmetic.
template<typename K, typename IDX>
__device__ IDX GetInOffset(const IDX out_offset, const K* indices, const IDX num_indices,
                           const IDX gather_dim_size, const IDX inner_dim_size, const IDX offset) {
  const IDX outer_dim_elem_cnt = num_indices * inner_dim_size;
  const IDX outer_idx = out_offset / outer_dim_elem_cnt;
  const IDX indices_idx = out_offset % outer_dim_elem_cnt / inner_dim_size;
  const IDX inner_idx = out_offset % inner_dim_size;
  assert(indices[indices_idx] >= 0);
  const IDX idx = indices[indices_idx] - offset;
  if (idx >= 0 && idx < gather_dim_size) {
    return outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size + inner_idx;
  } else {
    return -1;
  }
}

// Added by this PR: the kernel now uses NdIndexOffsetHelper to convert between flat offsets and
// (outer, index, inner) coordinates; out-of-range indices write zero.
template<typename T, typename K, typename IDX>
__global__ void GatherForwardGpu(const IDX elem_cnt, NdIndexOffsetHelper<IDX, 3> in_helper,
                                 NdIndexOffsetHelper<IDX, 3> out_helper, const K* indices,
                                 const T* in, const IDX gather_dim_size, T* out, const IDX offset) {
  IDX index[3];
  CUDA_1D_KERNEL_LOOP_T(IDX, i, elem_cnt) {
    out_helper.OffsetToNdIndex(i, index);
    index[1] = indices[index[1]] - offset;
    T v{};
    if (index[1] >= 0 && index[1] < gather_dim_size) { v = in[in_helper.NdIndexToOffset(index)]; }
    out[i] = v;
  }
}
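// Note (illustration, not part of the PR): NdIndexOffsetHelper<IDX, 3> performs the same div/mod
// decomposition the old GetInOffset spelled out by hand. A minimal sketch of the output-side
// mapping, assuming the helper is built as
// NdIndexOffsetHelper<IDX, 3>(outer_dim_size, num_indices, inner_dim_size); the struct and
// function names below are made up for this note.
#include <cstdint>

struct NdIndex3 {
  int64_t outer;
  int64_t index;
  int64_t inner;
};

inline NdIndex3 OutOffsetToNdIndex(int64_t offset, int64_t num_indices, int64_t inner_dim_size) {
  NdIndex3 nd;
  nd.outer = offset / (num_indices * inner_dim_size);                   // slowest-varying dim
  nd.index = offset % (num_indices * inner_dim_size) / inner_dim_size;  // which gathered row
  nd.inner = offset % inner_dim_size;                                   // position within the row
  return nd;
}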

// Added by this PR: prefer 32-bit offset arithmetic whenever both tensors are small enough.
bool IsSafeUseIndex32(int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size,
                      int64_t num_indices) {
  const int64_t in_elem_cnt = outer_dim_size * gather_dim_size * inner_dim_size;
  const int64_t out_elem_cnt = outer_dim_size * num_indices * inner_dim_size;
  return std::max(out_elem_cnt, in_elem_cnt) < GetMaxVal<int32_t>() / 2;
}

// Added by this PR: builds the offset helpers on the host and launches the kernel with the
// narrowest safe index type.
template<typename T, typename K>
void DispatchIndexSize(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,
                       int64_t inner_dim_size, int64_t num_indices, int64_t offset,
                       const K* indices, const T* in, T* out) {
  const int64_t out_elem_cnt = outer_dim_size * num_indices * inner_dim_size;
  if (IsSafeUseIndex32(outer_dim_size, gather_dim_size, inner_dim_size, num_indices)) {
    NdIndexOffsetHelper<int32_t, 3> in_helper(outer_dim_size, gather_dim_size, inner_dim_size);
    NdIndexOffsetHelper<int32_t, 3> out_helper(outer_dim_size, num_indices, inner_dim_size);
    GatherForwardGpu<T, K, int32_t><<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock,
                                      0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
        out_elem_cnt, in_helper, out_helper, indices, in, gather_dim_size, out, offset);
  } else {
    NdIndexOffsetHelper<int64_t, 3> in_helper(outer_dim_size, gather_dim_size, inner_dim_size);
    NdIndexOffsetHelper<int64_t, 3> out_helper(outer_dim_size, num_indices, inner_dim_size);
    GatherForwardGpu<T, K, int64_t><<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock,
                                      0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
        out_elem_cnt, in_helper, out_helper, indices, in, gather_dim_size, out, offset);
  }
}
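// Note (hedged, not from the PR text): the kernel performs division and modulo per output element
// inside OffsetToNdIndex, and 32-bit integer arithmetic is typically much cheaper on CUDA
// hardware than 64-bit, so the narrower index type is used whenever both element counts fit.
// A minimal standalone check mirroring IsSafeUseIndex32 above; the function name is illustrative.
#include <algorithm>
#include <cstdint>
#include <limits>

inline bool SafeForInt32Offsets(int64_t in_elem_cnt, int64_t out_elem_cnt) {
  // Keep a 2x margin under INT32_MAX, as the helper above does.
  return std::max(in_elem_cnt, out_elem_cnt) < std::numeric_limits<int32_t>::max() / 2;
}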

// Removed by this PR: the previous kernel recomputed the input offset through GetInOffset for
// every output element.
template<typename T, typename K, typename IDX>
__global__ void GatherForwardGpu(const IDX elem_cnt, const K* indices, const IDX num_indices,
                                 const T* in, const IDX gather_dim_size, const IDX inner_dim_size,
                                 T* out, const IDX offset) {
  CUDA_1D_KERNEL_LOOP_T(IDX, i, elem_cnt) {
    const IDX in_offset =
        GetInOffset<K, IDX>(i, indices, num_indices, gather_dim_size, inner_dim_size, offset);
    if (in_offset < 0) {
      out[i] = 0;
    } else {
      out[i] = in[in_offset];
    }
  }
}

// Added by this PR: reinterpret the data as a wider "movement" type when both pointers and the
// inner extent (in bytes) are suitably aligned, so each thread moves more bytes per access.
template<typename K, typename T>
bool TryDispatchMovementType(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,
                             int64_t inner_dim_size, int64_t num_indices, int64_t offset,
                             const K* indices, const void* in, void* out) {
  if (reinterpret_cast<uintptr_t>(in) % sizeof(T) == 0
      && reinterpret_cast<uintptr_t>(out) % sizeof(T) == 0 && inner_dim_size % sizeof(T) == 0) {
    DispatchIndexSize<T, K>(stream, outer_dim_size, gather_dim_size, inner_dim_size / sizeof(T),
                            num_indices, offset, indices, static_cast<const T*>(in),
                            static_cast<T*>(out));
    return true;
  } else {
    return false;
  }
}

// Removed by this PR: the Shape-based overload is superseded by the scalar-dim overload above.
bool IsSafeUseIndex32(const Shape& flat_in_shape, const int64_t num_indices) {
  const int64_t in_elem_cnt = flat_in_shape.elem_cnt();
  const int64_t out_elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2);
  return std::max(out_elem_cnt, in_elem_cnt) < GetMaxVal<int32_t>() / 2;
}

// Added by this PR: try the widest movement type first and fall back to narrower ones.
template<typename K>
void DispatchMovementSize(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,
                          int64_t inner_dim_size, int64_t num_indices, int64_t offset,
                          const K* indices, const void* in, void* out) {
  using Func = bool (*)(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,
                        int64_t inner_dim_size, int64_t num_indices, int64_t offset,
                        const K* indices, const void* in, void* out);
  Func funcs[] = {
      TryDispatchMovementType<K, ulonglong2>,  // 16B
      TryDispatchMovementType<K, uint64_t>,    // 8B
      TryDispatchMovementType<K, uint32_t>,    // 4B
      TryDispatchMovementType<K, uint16_t>,    // 2B
      TryDispatchMovementType<K, uint8_t>,     // 1B
  };
  for (size_t i = 0; i < sizeof(funcs) / sizeof(funcs[0]); ++i) {
    if (funcs[i](stream, outer_dim_size, gather_dim_size, inner_dim_size, num_indices, offset,
                 indices, in, out)) {
      break;
    }
  }
}
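// Note (illustration, not part of the PR): the fallback chain above amounts to choosing the
// largest power-of-two byte width that divides both pointer addresses and the row length in
// bytes. A sketch of that selection; SelectMovementSize is a made-up name.
#include <cstddef>
#include <cstdint>

inline size_t SelectMovementSize(const void* in, const void* out, size_t inner_bytes) {
  const uintptr_t in_addr = reinterpret_cast<uintptr_t>(in);
  const uintptr_t out_addr = reinterpret_cast<uintptr_t>(out);
  for (size_t width = 16; width > 1; width /= 2) {
    if (in_addr % width == 0 && out_addr % width == 0 && inner_bytes % width == 0) {
      return width;  // 16 -> ulonglong2, 8 -> uint64_t, 4 -> uint32_t, 2 -> uint16_t
    }
  }
  return 1;  // uint8_t always applies
}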

} // namespace
@@ -65,28 +107,8 @@
template<typename T, typename K>
struct GatherKernelUtilImpl<DeviceType::kCUDA, T, K> final {
  static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in,
                      const Shape& flat_in_shape, T* out, const int64_t offset) {
    // Removed by this PR: Forward previously picked the index width and launched the kernel here.
    const int64_t out_elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2);
    if (IsSafeUseIndex32(flat_in_shape, num_indices)) {
      GatherForwardGpu<T, K, int32_t><<<BlocksNum4ThreadsNum(out_elem_cnt),
                                        kCudaThreadsNumPerBlock, 0,
                                        stream->As<ep::CudaStream>()->cuda_stream()>>>(
          out_elem_cnt, indices, num_indices, in, flat_in_shape.At(1), flat_in_shape.At(2), out,
          offset);
    } else {
      GatherForwardGpu<T, K, int64_t><<<BlocksNum4ThreadsNum(out_elem_cnt),
                                        kCudaThreadsNumPerBlock, 0,
                                        stream->As<ep::CudaStream>()->cuda_stream()>>>(
          out_elem_cnt, indices, num_indices, in, flat_in_shape.At(1), flat_in_shape.At(2), out,
          offset);
    }
    // Added by this PR: Forward now only converts the inner extent to bytes and defers to
    // DispatchMovementSize, which makes the separate float16 specialization below unnecessary.
    DispatchMovementSize(stream, flat_in_shape.At(0), flat_in_shape.At(1),
                         flat_in_shape.At(2) * sizeof(T), num_indices, offset, indices, in, out);
  }
};

// Removed by this PR: the float16 specialization simply forwarded to the half implementation;
// the byte-based dispatch above covers it.
template<typename K>
struct GatherKernelUtilImpl<DeviceType::kCUDA, float16, K> final {
  static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const float16* in,
                      const Shape& flat_in_shape, float16* out, const int64_t offset) {
    GatherKernelUtilImpl<DeviceType::kCUDA, half, K>::Forward(
        stream, indices, num_indices, reinterpret_cast<const half*>(in), flat_in_shape,
        reinterpret_cast<half*>(out), offset);
  }
};
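// Note (illustration, not part of the PR): for reference, the semantics the CUDA path above
// implements are out[o][j][i] = in[o][indices[j] - offset][i], with out-of-range indices
// producing zeros. A plain CPU sketch, element-typed rather than byte-typed; GatherReference
// is a made-up name.
#include <cstdint>
#include <vector>

template<typename T, typename K>
void GatherReference(const std::vector<T>& in, const std::vector<K>& indices,
                     int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size,
                     int64_t offset, std::vector<T>* out) {
  const int64_t num_indices = static_cast<int64_t>(indices.size());
  out->assign(outer_dim_size * num_indices * inner_dim_size, T(0));
  for (int64_t o = 0; o < outer_dim_size; ++o) {
    for (int64_t j = 0; j < num_indices; ++j) {
      const int64_t idx = static_cast<int64_t>(indices[j]) - offset;
      if (idx < 0 || idx >= gather_dim_size) { continue; }  // out of range -> stays zero
      for (int64_t i = 0; i < inner_dim_size; ++i) {
        (*out)[(o * num_indices + j) * inner_dim_size + i] =
            in[(o * gather_dim_size + idx) * inner_dim_size + i];
      }
    }
  }
}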
