Skip to content

Commit

Permalink
Merge pull request #12920 from nluehr/cuda9-internal-error-fix
Browse files Browse the repository at this point in the history
Workaround for NVCC 9.0 internal error
  • Loading branch information
sb2nov authored Sep 26, 2017
2 parents f4f5966 + db12753 commit e5ba515
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tensorflow/core/kernels/reduction_gpu_kernels.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ __global__ void ColumnReduceMax16ColumnsKernel(
if (row * num_cols + col < num_rows * num_cols)
sum = in[row * num_cols + col];

__shared__ value_type partial_sums[32][33];
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
__shared__ value_type partial_sums[32 * 33];

row += rows_per_warp * gridDim.y * blockDim.y;
for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
Expand All @@ -283,16 +285,16 @@ __global__ void ColumnReduceMax16ColumnsKernel(
if (lane < num_cols) sum = op(sum, tmp);
}

if (lane < num_cols) partial_sums[lane][threadIdx.y] = sum;
if (lane < num_cols) partial_sums[lane * 33 + threadIdx.y] = sum;

__syncthreads();

if (threadIdx.y == 0 && threadIdx.x < num_cols) {
value_type s = partial_sums[threadIdx.x][0];
value_type s = partial_sums[threadIdx.x * 33];

if (blockDim.y > 1) {
for (int row = 1; row < blockDim.y; ++row) {
s = op(s, partial_sums[threadIdx.x][row]);
s = op(s, partial_sums[threadIdx.x * 33 + row]);
}
}

Expand All @@ -313,7 +315,9 @@ __global__ void ColumnReduceKernel(
if (row < num_rows && col < num_cols)
sum = in[row * num_cols + col];

__shared__ value_type partial_sums[32][33];
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
__shared__ value_type partial_sums[32 * 33];

row += gridDim.y * blockDim.y;

Expand All @@ -323,12 +327,12 @@ __global__ void ColumnReduceKernel(
}
}

partial_sums[threadIdx.x][threadIdx.y] = sum;
partial_sums[threadIdx.x * 33 + threadIdx.y] = sum;

__syncthreads();

if (threadIdx.y == 0 && col < num_cols) {
value_type s = partial_sums[threadIdx.x][0];
value_type s = partial_sums[threadIdx.x * 33];

// only include input values in the reduction
// elem block_rows
Expand All @@ -344,7 +348,7 @@ __global__ void ColumnReduceKernel(
min(blockDim.y, num_rows - blockIdx.y * blockDim.y);

for (int row = 1; row < numRowsThisBlock; ++row) {
s = op(s, partial_sums[threadIdx.x][row]);
s = op(s, partial_sums[threadIdx.x * 33 + row]);
}

out[col * gridDim.y + blockIdx.y] = s;
Expand Down

0 comments on commit e5ba515

Please sign in to comment.