
Add residual support for shmem staging iterator used in back-to-back GEMM fusion. This allows support of problem_size_0_n that is not a multiple of 32. (NVIDIA#590)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
hwu36 authored Aug 15, 2022
1 parent e66bfcb commit 497b499
Showing 8 changed files with 462 additions and 35 deletions.
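
Background: the multiple-of-32 constraint relaxed by this commit presumably comes from the 32-column width of the shmem staging tile for GEMM0's accumulator. A minimal sketch of the residual decomposition this enables; all sizes below are hypothetical, for illustration only:

#include <cstdio>

int main() {
  int const kTileN = 32;       // staging tile width, per the commit message's constraint
  int problem_size_0_n = 40;   // hypothetical GEMM0 N, not a multiple of 32

  int full_tiles = problem_size_0_n / kTileN;                // 1 full 32-column tile
  int residual_n = problem_size_0_n % kTileN;                // 8-column residual tile
  int staged     = (problem_size_0_n + kTileN - 1) / kTileN; // 2 tiles staged in total

  printf("full=%d residual=%d staged=%d\n", full_tiles, residual_n, staged);
  return 0;
}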
examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h (2 changes: 1 addition & 1 deletion)
@@ -341,7 +341,7 @@ struct B2bGemm {
     OutputOp0 output_op_0(params.output_op_0);
 
     // Construct thread-scoped matrix multiply
-    B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+    B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
 
     typename B2bMma::FragmentC0 src_accum;
     typename B2bMma::FragmentC1 accumulators;
@@ -267,7 +267,9 @@ class B2bMmaMultistage :
       ///< ID of warp
       int warp_idx,
       ///< ID of each thread within a warp
-      int lane_idx
+      int lane_idx,
+      ///< GEMM0 N is used for accumulator extent
+      int problem_size_0_n
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
@@ -639,7 +641,6 @@ class B2bMmaMultistage :
 
 }
 
-
   // 2nd Gemm
 
   /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@@ -657,12 +658,11 @@ class B2bMmaMultistage :
     tb_frag_A1_bias.clear();
     iterator_A1_bias.load(tb_frag_A1_bias);
     ++iterator_A1_bias;
 
-
     //
     // Prologue
     //
-    int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
+    int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
 
     // Issue several complete stages
     CUTLASS_PRAGMA_UNROLL
@@ -750,9 +750,9 @@ class B2bMmaMultistage :
     // Mainloop
     //
 
+    gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1);
     CUTLASS_PRAGMA_UNROLL
-    for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
-        gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
+    for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
       //
       // Loop over GEMM K dimension
       //
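
Both gemm_k_iterations_1 changes above replace floor division with round-up division, so the trailing partial warp-gemm iteration is still issued when FragmentIteratorA1::Policy::kIterations is not a multiple of Base::kWarpGemmIterations1. A standalone sketch of the pattern, with hypothetical values:

#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // e.g. 5 fragment iterations consumed in warp-gemm groups of 4:
  assert(5 / 4 == 1);           // old count: the partial group was dropped
  assert(ceil_div(5, 4) == 2);  // new count: the partial group is issued
  assert(ceil_div(8, 4) == 2);  // exact multiples are unaffected
  return 0;
}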
@@ -276,13 +276,15 @@ class B2bMmaMultistageSmemAccumulator :
       ///< ID of warp
       int warp_idx,
       ///< ID of each thread within a warp
-      int lane_idx
+      int lane_idx,
+      ///< GEMM0 N is used for accumulator extent
+      int problem_size_0_n
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
     smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
     smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
-    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
+    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx),
     smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
   {
     // Compute warp location within threadblock tile by mapping the warp_id to
@@ -228,7 +228,8 @@ class B2bMmaPipelined :
     typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
     int thread_idx,                                     ///< ID within the threadblock
     int warp_idx,                                       ///< ID of warp
-    int lane_idx                                        ///< ID of each thread within a warp
+    int lane_idx,                                       ///< ID of each thread within a warp
+    int problem_size_0_n                                ///< GEMM0 N is used for accumulator extent
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
@@ -236,13 +236,14 @@ class B2bMmaPipelinedSmemAccumulator :
     typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
     int thread_idx,                                     ///< ID within the threadblock
     int warp_idx,                                       ///< ID of warp
-    int lane_idx                                        ///< ID of each thread within a warp
+    int lane_idx,                                       ///< ID of each thread within a warp
+    int problem_size_0_n                                ///< GEMM0 N is used for accumulator extent
   ):
     Base(shared_storage, thread_idx, warp_idx, lane_idx),
     smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
     smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
     smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
-    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
+    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx),
     smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
 
     // Compute warp location within threadblock tile by mapping the warp_id to
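
In both SmemAccumulator constructors above, warp_tile_iterator_A1_ now receives an explicit accumulator extent, {Base::WarpGemm1::kM, problem_size_0_n}, instead of assuming a full tile. The sketch below shows the kind of bounds guard such an extent enables; it is illustrative only, not the iterator's actual implementation:

#include <cassert>

struct Extent { int m, n; };  // stand-in for the {kM, problem_size_0_n} pair

// Accesses beyond the extent are masked, so GEMM0's N may end mid-tile.
bool in_bounds(int row, int col, Extent e) {
  return row < e.m && col < e.n;
}

int main() {
  Extent e{64, 40};              // hypothetical: warp-tile M = 64, GEMM0 N = 40
  assert(in_bounds(0, 39, e));   // last valid column of the residual tile
  assert(!in_bounds(0, 40, e));  // first column past problem_size_0_n
  return 0;
}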
@@ -43,7 +43,7 @@
 #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
 #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
 #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
-#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
 
 #include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
 #include "threadblock/b2b_mma_multistage_smem_accumulator.h"
@@ -158,11 +158,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
 
   static int const kThreadCount = 32;
   // load warp tile from Shared Memory accumulator
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
-      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped pipelined matrix multiply
   using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
@@ -303,11 +303,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
 
   static int const kThreadCount = 32;
   // load warp tile from Shared Memory accumulator
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
-      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped pipelined matrix multiply
   using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
@@ -436,11 +436,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
 
   static int const kThreadCount = 32;
   // load warp tile from Shared Memory accumulator
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
-      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
   // Define the threadblock-scoped pipelined matrix multiply
   using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
@@ -574,11 +574,11 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
 
   static int const kThreadCount = 32;
   // load warp tile from Shared Memory accumulator
-  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
-      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
+  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+      MatrixShape<WarpShape1::kM, WarpShape1::kK>, cutlass::gemm::Operand::kA,
       ElementA, SmemAccumulatorLayout,
       MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
+      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>;
 
 
   // Define the threadblock-scoped multistage matrix multiply
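
All four DefaultB2bMma specializations above make the same three-part swap: MmaTensorOpMultiplicandTileIterator (or its Canonical variant) becomes MmaTensorOpMultiplicandTileAccessIterator, the iterator's K extent widens from a single instruction (InstructionShape::kK) to the whole warp tile (WarpShape1::kK), and a trailing true is passed, presumably the flag that enables the residual handling named in the commit title. A compile-time sketch of the shape change only, using hypothetical 64/8/32 dimensions:

#include "cutlass/matrix_shape.h"

// Old extent: one instruction's K per tile. New extent: the full warp-tile K.
using OldShape = cutlass::MatrixShape<64, 8>;   // WarpShape1::kM x InstructionShape::kK
using NewShape = cutlass::MatrixShape<64, 32>;  // WarpShape1::kM x WarpShape1::kK

static_assert(NewShape::kColumn > OldShape::kColumn,
              "the access iterator now covers the whole warp-tile K extent");

int main() { return 0; }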
