Skip to content

Commit

Permalink
[src] Change warp-synchronous to cub::BlockReduce (safer but slower) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
desh2608 authored and danpovey committed Mar 10, 2019
1 parent 2f95609 commit bcfe3f8
Showing 1 changed file with 38 additions and 133 deletions.
171 changes: 38 additions & 133 deletions src/cudamatrix/cu-kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <limits>
#include <math_constants.h>
#include "cudamatrix/cu-kernels-ansi.h"

#include <cub/block/block_reduce.cuh>


/***********************************************************************
Expand Down Expand Up @@ -958,6 +958,7 @@ static void _trace_mat_mat(const Real* A, const Real* B, MatrixDim dA,
Real trans[TileDim][TileDim + 1];
Real sum[CU1DBLOCK];
} smem;

// linear thread id;
const int32_cuda tid = threadIdx.y * blockDim.x + threadIdx.x;
const int32_cuda grid_height = gridDim.y * TileDim;
Expand Down Expand Up @@ -1021,6 +1022,7 @@ static void _trace_mat_mat(const Real* A, const Real* B, MatrixDim dA,
if (tid == 0) {
value[blockIdx.y * gridDim.x + blockIdx.x] = smem.sum[0];
}

}

// _trace_mat_mat_trans reduce the partial sum to
Expand All @@ -1030,6 +1032,7 @@ __global__
static void _trace_mat_mat_trans(const Real* A, const Real* B, MatrixDim dA,
int B_stride, Real* value) {
__shared__ Real ssum[CU1DBLOCK];

// linear thread id;
const int32_cuda tid = threadIdx.y * blockDim.x + threadIdx.x;
const int32_cuda j = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -1046,7 +1049,7 @@ static void _trace_mat_mat_trans(const Real* A, const Real* B, MatrixDim dA,
}
ssum[tid] = tsum;
__syncthreads();

// Block reduce
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
Expand Down Expand Up @@ -2485,6 +2488,8 @@ template<typename Real>
__global__
static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) {
__shared__ Real smem[CU1DBLOCK];
typedef cub::BlockReduce<Real, CU1DBLOCK> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp_storage;
const int i = blockIdx.x;
const int x_start = i * src_stride;
const int y_start = i * d.stride;
Expand All @@ -2496,24 +2501,9 @@ static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) {
for (int j = tid; j < d.cols; j += CU1DBLOCK) {
tmax = fmax(tmax, x[x_start + j]);
}
smem[tid] = tmax;
__syncthreads();

// reduce to 2x warpSize elements per row
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift) {
smem[tid] = fmax(smem[tid], smem[tid + shift]);
}
__syncthreads();
}

// reduce to 1 element per row
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
smem[tid] = fmax(smem[tid], smem[tid + shift]);
}
tmax = BlockReduceT(temp_storage).Reduce(tmax, cub::Max());
if (tid == 0) {
smem[0] = tmax;
}

// broadcast max to all threads
Expand All @@ -2526,24 +2516,9 @@ static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) {
for (int j = tid; j < d.cols; j += CU1DBLOCK) {
tsum += exp(x[x_start + j] - max);
}
smem[tid] = tsum;
__syncthreads();

// reduce to 2x warpSize elements per row
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift) {
smem[tid] += smem[tid + shift];
}
__syncthreads();
}

// reduce to 1 element per row
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
smem[tid] += smem[tid + shift];
}
tsum = BlockReduceT(temp_storage).Sum(tsum);
if (tid == 0) {
smem[0] = tsum;
}

// broadcast sum to all threads
Expand Down Expand Up @@ -2577,41 +2552,23 @@ static void _normalize_per_row(Real *y, int y_stride, const Real *x,
const int i = blockIdx.x;
const int tid = threadIdx.x;
const Real* x_row = x + i * x_d.stride;
typedef cub::BlockReduce<Real, CU1DBLOCK> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp_storage;
__shared__ Real ssum[CU1DBLOCK];

// Reduce x_j^2 to CU1DBLOCK elements per row
Real tsum = Real(0);
for (int j = tid; j < x_d.cols; j += CU1DBLOCK) {
tsum += x_row[j] * x_row[j];
}
ssum[tid] = tsum;
tsum = BlockReduceT(temp_storage).Sum(tsum);
__syncthreads();

// Tree reduce to 2x warpSize elements per row
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift)
ssum[tid] += ssum[tid + shift];
__syncthreads();
}

// Reduce last warp to 1 element per row.
// Threads implicitly synchronized within a warp.
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
ssum[tid] += ssum[tid + shift];
}
}


const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66
if (tid == 0) {
ssum[0] = sqrt(
fmax(ssum[0] / (target_rms * target_rms * x_d.cols), kSquaredNormFloor));
}
ssum[tid] = sqrt(
fmax(tsum / (target_rms * target_rms * x_d.cols), kSquaredNormFloor));

// Broadcast floored stddev to all threads.
__syncthreads();
const Real stddev_div_target_rms = ssum[0];
const Real scale = Real(1) / stddev_div_target_rms;

Expand All @@ -2626,7 +2583,6 @@ static void _normalize_per_row(Real *y, int y_stride, const Real *x,
}
}


template<typename Real>
__global__
static void _diff_normalize_per_row(Real *id, int id_stride, const Real *iv,
Expand Down Expand Up @@ -2722,6 +2678,8 @@ __global__
static void _log_softmax_reduce(Real* y, const Real* x, MatrixDim y_dim,
int x_stride) {
__shared__ Real smem[CU1DBLOCK];
typedef cub::BlockReduce<Real, CU1DBLOCK> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp_storage;
const int i = blockIdx.x;
const int x_start = i * x_stride;
const int y_start = i * y_dim.stride;
Expand All @@ -2733,23 +2691,9 @@ static void _log_softmax_reduce(Real* y, const Real* x, MatrixDim y_dim,
for (int j = tid; j < y_dim.cols; j += CU1DBLOCK) {
tmax = fmax(tmax, x[x_start + j]);
}
smem[tid] = tmax;
__syncthreads();

// reduce to 2x warpSize elements per row
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift) {
smem[tid] = fmax(smem[tid], smem[tid + shift]);
}
__syncthreads();
}

// reduce to 1 element per row
if (tid < warpSize) {
for (int shift = warpSize; shift > 0; shift >>= 1) {
smem[tid] = fmax(smem[tid], smem[tid + shift]);
}
tmax = BlockReduceT(temp_storage).Reduce(tmax, cub::Max());
if (tid == 0) {
smem[0] = tmax;
}

// broadcast max to all threads
Expand All @@ -2762,23 +2706,9 @@ static void _log_softmax_reduce(Real* y, const Real* x, MatrixDim y_dim,
for (int j = tid; j < y_dim.cols; j += CU1DBLOCK) {
tsum += exp(x[x_start + j] - max);
}
smem[tid] = tsum;
__syncthreads();

// reduce to 2x warpSize elements per row
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift) {
smem[tid] += smem[tid + shift];
}
__syncthreads();
}

// reduce to 1 element per row
if (tid < warpSize) {
for (int shift = warpSize; shift > 0; shift >>= 1) {
smem[tid] += smem[tid + shift];
}
tsum = BlockReduceT(temp_storage).Sum(tsum);
if (tid == 0) {
smem[0] = tsum;
}

// broadcast sum to all threads
Expand Down Expand Up @@ -3024,6 +2954,9 @@ static void _diff_softmax(Real* x, const MatrixDim dim, const Real* value,
const int value_stride, const Real* diff,
const int diff_stride) {
__shared__ Real ssum[CU1DBLOCK];
typedef cub::BlockReduce<Real, CU1DBLOCK> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp_storage;

const int tid = threadIdx.x;
const int i = blockIdx.x;
const int value_start = i * value_stride;
Expand All @@ -3035,24 +2968,9 @@ static void _diff_softmax(Real* x, const MatrixDim dim, const Real* value,
for (int j = tid; j < dim.cols; j += CU1DBLOCK) {
tsum += value[value_start + j] * diff[diff_start + j];
}
ssum[tid] = tsum;
__syncthreads();

// Tree reduce to 2x warpSize elements.
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift) {
ssum[tid] += ssum[tid + shift];
}
__syncthreads();
}

// Warp reduce to 1 element. Threads implicitly synchronized within a warp.
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
ssum[tid] += ssum[tid + shift];
}
tsum = BlockReduceT(temp_storage).Sum(tsum);
if (tid == 0) {
ssum[0] = tsum;
}

// Broadcast result to all threads
Expand All @@ -3078,6 +2996,8 @@ static void _diff_log_softmax(const MatrixDim in_deriv_dim,
Real* in_deriv) {

__shared__ Real ssum[CU1DBLOCK];
typedef cub::BlockReduce<Real, CU1DBLOCK> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp_storage;
const int tid = threadIdx.x;
const int i = blockIdx.x;
const int out_value_start = i * out_value_stride;
Expand All @@ -3089,24 +3009,9 @@ static void _diff_log_softmax(const MatrixDim in_deriv_dim,
for (int j = tid; j < in_deriv_dim.cols; j += CU1DBLOCK) {
tsum += out_deriv[out_deriv_start + j];
}
ssum[tid] = tsum;
__syncthreads();

// Tree reduce to 2x warpSize elements.
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift) {
ssum[tid] += ssum[tid + shift];
}
__syncthreads();
}

// Warp reduce to 1 element. Threads implicitly synchronized within a warp.
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
ssum[tid] += ssum[tid + shift];
}
tsum = BlockReduceT(temp_storage).Sum(tsum);
if (tid == 0) {
ssum[0] = tsum;
}

// Broadcast result to all threads
Expand Down

0 comments on commit bcfe3f8

Please sign in to comment.