Skip to content

Commit

Permalink
fix race condition when h < stride_h or w < stride_w (NVIDIA#562)
Browse files Browse the repository at this point in the history
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
  • Loading branch information
hwu36 and hwu36 authored Jul 12, 2022
1 parent fb379ea commit e7a61c7
Showing 1 changed file with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,26 @@ struct ImplicitGemmConvolutionStridedDgrad {
int start_r, start_s;
params.stride_w_divmod(start_r, start_s, filter_tile_m);

int filter_r = start_r;
int filter_s = start_s;

if (params.problem_size.mode == Mode::kConvolution) {
filter_r = (params.problem_size.R - 1 - filter_r);
filter_s = (params.problem_size.S - 1 - filter_s);
}

// Starting h, w positions for filter position in gemm_k=0
int start_h, start_w;
strided_dgrad_starting_coords(
params.problem_size,
params.stride_h_divmod, params.stride_w_divmod,
filter_r, filter_s,
start_h, start_w);

if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) {
return;
}

typename Mma::FragmentC accumulators;

accumulators.clear();
Expand Down

0 comments on commit e7a61c7

Please sign in to comment.