Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix build for cuda_bf16 #8862

Merged
merged 3 commits into from
Aug 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions oneflow/user/kernels/normalization_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ limitations under the License.
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/kernel/cuda_graph_support.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#if CUDA_VERSION >= 11000
#include "oneflow/core/device/cuda_pseudo_bfloat16.h"
#endif
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000

namespace oneflow {

Expand Down Expand Up @@ -251,7 +252,7 @@ constexpr int64_t kCudaWarpSize = 32;
template<typename T>
__global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) {
const int32_t lane_id = threadIdx.x % kCudaWarpSize;
const T zero = static_cast<T>(0);
const T zero = static_cast<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, n) {
const T x_val = x[i];
const bool is_positive = (x_val > zero);
Expand All @@ -264,7 +265,7 @@ __global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) {
template<typename T>
__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
const int32_t lane_id = threadIdx.x % kCudaWarpSize;
const T zero = static_cast<T>(0);
const T zero = static_cast<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, n) {
const T sum = x[i] + addend[i];
const bool is_positive = (sum > zero);
Expand Down Expand Up @@ -296,6 +297,21 @@ __global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T*
}
}

#if CUDA_VERSION >= 11000

template<>
__global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
nv_bfloat16* addend_diff) {
int32_t lane_id = threadIdx.x % kCudaWarpSize;
CUDA_1D_KERNEL_LOOP(i, n) {
int32_t mask_val = mask[i / kCudaWarpSize];
bool is_positive = mask_val & (1 << lane_id);
addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
}
}

#endif

template<typename T>
void ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
Expand Down