66 changes: 25 additions & 41 deletions paddle/fluid/operators/group_norm_op.cu
@@ -152,6 +152,21 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
   }
 }
 
+template <typename T>
+__device__ __forceinline__ void ReduceMeanAndVar(T* mean, T* var, T x_mean,
+                                                 T x_var, int size) {
+  const int nc = blockIdx.x;
+  x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
+      x_mean, kps::AddFunctor<T>());
+  x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
+      x_var, kps::AddFunctor<T>());
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    mean[nc] = static_cast<T>(x_mean / size);
+    var[nc] = static_cast<T>(x_var / size);
+  }
+}
+
 template <typename T>
 __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
   int i = blockIdx.x;
@@ -162,10 +177,7 @@ __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
     x_mean += val;
     x_var += val * val;
   }
-  x_mean /= size;
-  x_var /= size;
-  CudaAtomicAddWithWarp(&mean[i], x_mean);
-  CudaAtomicAddWithWarp(&var[i], x_var);
+  ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size);
 }
 
 template <typename T, typename AccT, int VecSize>
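The new ReduceMeanAndVar helper folds each thread's partial sums into block-wide sums and has thread 0 write the result exactly once, which is why the CudaAtomicAddWithWarp calls, and the set_zero pre-passes they required, can be dropped. As a minimal, self-contained sketch of the same pattern in raw CUDA: BlockReduceSum and MeanAndSquareMean below are hypothetical stand-ins for kps::details::BlockXReduce and the NCHW kernels, assuming blockDim.x is a multiple of warpSize (the launch code clamps it to at least kWarpSize) and at most 1024.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for kps::details::BlockXReduce: a block-wide sum
// built from warp shuffles. Assumes a 1-D block whose size is a multiple
// of warpSize, with at most 1024 threads.
__device__ float BlockReduceSum(float val) {
  __shared__ float warp_sums[32];        // one partial sum per warp
  const int lane = threadIdx.x % warpSize;
  const int warp = threadIdx.x / warpSize;
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  if (lane == 0) warp_sums[warp] = val;  // lane 0 holds each warp's sum
  __syncthreads();
  const int num_warps = blockDim.x / warpSize;
  val = (threadIdx.x < num_warps) ? warp_sums[lane] : 0.f;
  if (warp == 0)                         // first warp folds the partials
    for (int offset = warpSize / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(0xffffffff, val, offset);
  return val;  // the full sum ends up in thread 0
}

// Same shape as ScalarGetMeanAndVarNCHW: one block per (n, g) slice,
// threads stride over `size` elements, thread 0 stores E[x] and E[x^2].
__global__ void MeanAndSquareMean(const float* x, float* mean, float* var,
                                  int size) {
  const float* row = x + blockIdx.x * size;
  float x_mean = 0.f, x_var = 0.f;
  for (int i = threadIdx.x; i < size; i += blockDim.x) {
    const float v = row[i];
    x_mean += v;
    x_var += v * v;
  }
  x_mean = BlockReduceSum(x_mean);
  x_var = BlockReduceSum(x_var);
  if (threadIdx.x == 0) {
    mean[blockIdx.x] = x_mean / size;
    var[blockIdx.x] = x_var / size;  // E[x^2]; variance is derived later
  }
}

int main() {
  const int rows = 2, size = 1000;
  float *x, *mean, *var;
  cudaMallocManaged(&x, rows * size * sizeof(float));
  cudaMallocManaged(&mean, rows * sizeof(float));
  cudaMallocManaged(&var, rows * sizeof(float));
  for (int i = 0; i < rows * size; ++i) x[i] = 1.f;  // E[x] = E[x^2] = 1
  MeanAndSquareMean<<<rows, 128>>>(x, mean, var, size);
  cudaDeviceSynchronize();
  std::printf("mean=%f E[x^2]=%f\n", mean[0], var[0]);
  return 0;
}
```

Besides removing the extra fill kernels, a single deterministic write avoids the run-to-run nondeterminism of floating-point atomics.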
@@ -174,21 +186,12 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
   int i = blockIdx.x;
   AccT x_mean = static_cast<AccT>(0);
   AccT x_var = static_cast<AccT>(0);
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   x += i * size;
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   phi::Array<const T*, 1> ins;
   ins[0] = x;
   ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
-
-  x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      x_mean, kps::AddFunctor<AccT>());
-  x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      x_var, kps::AddFunctor<AccT>());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    mean[i] = static_cast<T>(x_mean / size);
-    var[i] = static_cast<T>(x_var / size);
-  }
+  ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
 }
 
 template <typename T, int flags>
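Both vectorized kernels (here and in VectorizedGetDsDbCUDAKernel below) now compute input_offset after the pointer bump x += i * size. The offset tells ThreadReduce how far the row is from a vector-aligned address, and it is the block's own row that gets loaded, so measuring the alignment of the tensor base is wrong whenever the row stride in bytes is not a multiple of ALIGN_BYTES. A small host-side sketch of the effect; the base address, row length, and ALIGN_BYTES = 16 (a float4-wide load) are made-up example values:

```cuda
#include <cstdint>
#include <cstdio>

// Host-side sketch of the fix: with an odd row length, each row starts at
// a different misalignment, so the offset must come from the shifted
// pointer, not from the tensor base.
int main() {
  const int ALIGN_BYTES = 16;
  const uintptr_t base = 256;  // a 16-byte-aligned tensor start (example)
  const int size = 5;          // elements per (n, g) row (example)
  for (int i = 0; i < 4; ++i) {
    uintptr_t row = base + uintptr_t(i) * size * sizeof(float);
    int offset = row % ALIGN_BYTES / sizeof(float);
    // Prints offsets 0, 1, 2, 3: the offset computed from the un-shifted
    // base pointer (the old code) is only correct for row 0.
    std::printf("row %d starts %d element(s) past alignment\n", i, offset);
  }
  return 0;
}
```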
@@ -272,10 +275,6 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     Tensor temp_var;
     temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
-
-    set_zero(dev_ctx, mean, static_cast<T>(0));
-    set_zero(dev_ctx, &temp_var, static_cast<T>(0));
-
     auto* x_data = x->data<T>();
     auto* y_data = y->data<T>();
     auto* mean_data = mean->data<T>();
Expand Down Expand Up @@ -319,7 +318,7 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 grids(x_dims[0] * groups);
dim3 blocks(block_size_nchw);
if (size < vec_size) {
if (size < vec_size * block_size_nchw) {
ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, size);
} else {
@@ -328,6 +327,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
             x_data, mean_data, temp_var_data, size);
       }
     } else {
+      set_zero(dev_ctx, mean, static_cast<T>(0));
+      set_zero(dev_ctx, &temp_var, static_cast<T>(0));
       GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
           x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
           temp_var_data);
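Two behavioural changes meet here. First, the dispatch threshold: the old test size < vec_size sent almost every shape down the vectorized path; the new test size < vec_size * block_size_nchw falls back to the scalar kernel unless the block can complete at least one full vectorized sweep. Second, the set_zero calls now run only in this else branch, since GroupNormForwardGetMeanAndVar is the one kernel left that accumulates into mean and temp_var with atomics; the NCHW kernels overwrite their outputs via ReduceMeanAndVar. A host-side illustration of the new dispatch rule, with example values for vec_size and block_size_nchw (the real code derives them from the tensor shape and alignment):

```cuda
#include <cstdio>

int main() {
  const int vec_size = 4;          // elements per vectorized load (example)
  const int block_size_nchw = 128; // threads per block (example)
  const int sizes[] = {3, 64, 511, 512, 4096};
  for (int size : sizes) {
    // Old rule: scalar only when size < 4. New rule: scalar whenever the
    // block cannot do even one full vectorized sweep over the row.
    const bool scalar = size < vec_size * block_size_nchw;
    std::printf("size=%4d -> %s kernel\n", size,
                scalar ? "ScalarGetMeanAndVarNCHW"
                       : "VectorizedGetMeanAndVarNCHW");
  }
  return 0;
}
```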
@@ -424,24 +425,15 @@ __global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
   int i = blockIdx.x;
   AccT ds_sum = static_cast<AccT>(0);
   AccT db_sum = static_cast<AccT>(0);
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
   x += i * imsize;
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
 
   phi::Array<const T*, 2> ins;
   ins[0] = x;
   ins[1] = dy;
   ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
                                     &ds_sum);
-
-  ds_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      ds_sum, kps::AddFunctor<AccT>());
-  db_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
-      db_sum, kps::AddFunctor<AccT>());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    ds[i] = ds_sum;
-    db[i] = db_sum;
-  }
+  ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
 }
 
 template <typename T>
@@ -455,8 +447,7 @@ __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
     ds_sum += dy[index] * x[index];
     db_sum += dy[index];
   }
-  CudaAtomicAddWithWarp(&ds[nc], ds_sum);
-  CudaAtomicAddWithWarp(&db[nc], db_sum);
+  ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
 }
 
 template <typename T>
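Both backward ds/db kernels reuse ReduceMeanAndVar with size = 1, so the divisions inside the helper are no-ops and it degenerates into "block-sum and store": db[nc] receives the sum of dy and ds[nc] the sum of dy * x for that (n, c) row. Passing the destinations in (db, ds) order matches the (db_sum, ds_sum) argument order. A trivial host-side check of the size = 1 identity, with made-up example data:

```cuda
#include <cassert>

int main() {
  const float x[] = {1.f, 2.f, 3.f};
  const float dy[] = {0.5f, -1.f, 2.f};
  float db_sum = 0.f, ds_sum = 0.f;
  for (int i = 0; i < 3; ++i) {
    db_sum += dy[i];         // db accumulates dy
    ds_sum += dy[i] * x[i];  // ds accumulates dy * x
  }
  const int size = 1;  // as passed at both call sites above
  assert(db_sum / size == db_sum);  // dividing by 1 changes nothing,
  assert(ds_sum / size == ds_sum);  // so the helper just stores the sums
  return 0;
}
```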
@@ -641,13 +632,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
       }
       block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
       dim3 blocks(block_size_nchw);
-      if (imsize < vec_size) {
-        if (d_scale) {
-          set_zero(dev_ctx, d_scale, static_cast<T>(0));
-        }
-        if (d_bias) {
-          set_zero(dev_ctx, d_bias, static_cast<T>(0));
-        }
+      if (imsize < vec_size * block_size_nchw) {
         ScalarGetDsDbCUDAKernel<
             T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
             imsize, x_data, dy_data, ds_data, db_data);
Expand Down Expand Up @@ -687,7 +672,6 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data,
dy_data, d_x_data);
}

} else {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));