Skip to content

Commit

Permalink
Add grouped b2b GEMM (NVIDIA#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkosaian authored Jun 5, 2023
1 parent fde824a commit 87349d3
Show file tree
Hide file tree
Showing 15 changed files with 1,644 additions and 107 deletions.
1 change: 1 addition & 0 deletions examples/13_two_tensor_op_fusion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ endforeach()
foreach(FUSION_GEMM_EXAMPLE
fused_two_gemms_f16_sm75_rf
fused_two_gemms_f16_sm75_shmem
fused_two_gemms_grouped_f16_sm80_rf
fused_two_gemms_f16_sm80_rf
fused_two_gemms_f16_sm80_shmem
fused_two_gemms_s8_sm75_rf
Expand Down
450 changes: 450 additions & 0 deletions examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h

Large diffs are not rendered by default.

91 changes: 1 addition & 90 deletions examples/13_two_tensor_op_fusion/device/b2b_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,96 +185,7 @@ class B2bGemm {
SmemAccumulator
>::B2bGemmKernel;

/// Argument structure for a batched back-to-back (B2B) GEMM: two fused GEMMs
/// where the output of GEMM0 (optionally scaled/biased per-channel) feeds
/// GEMM1. Mirrors the layout of CUTLASS GemmUniversal-style argument structs.
struct Arguments {

  //
  // Data members
  //

  GemmUniversalMode mode;                                     ///< kGemm / kBatched / kArray execution mode
  GemmCoord problem_size_0;                                   ///< GEMM0 extents (m, n, k)
  GemmCoord problem_size_1;                                   ///< GEMM1 extents (m, n, k)
  TensorRef<ElementA const, LayoutA> ref_A0;                  ///< GEMM0 operand A
  TensorRef<ElementB const, LayoutB> ref_B0;                  ///< GEMM0 operand B
  TensorRef<ElementC const, LayoutC> ref_C0;                  ///< GEMM0 source accumulator / bias tensor
  TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;  ///< per-channel scale applied to GEMM0 output
  TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;   ///< per-channel bias applied to GEMM0 output
  TensorRef<ElementB const, LayoutB> ref_B1;                  ///< GEMM1 operand B (A1 is GEMM0's output)
  TensorRef<ElementC const, LayoutC> ref_C1;                  ///< GEMM1 source accumulator / bias tensor
  TensorRef<ElementC, LayoutC> ref_D1;                        ///< final destination tensor
  int64_t batch_stride_A0;                                    ///< element stride between consecutive batches of A0
  int64_t batch_stride_B0;
  int64_t batch_stride_B1;
  int64_t batch_stride_C1;
  int64_t batch_stride_D1;
  int64_t batch_stride_Bias0;
  int64_t batch_stride_Scale0;
  typename EpilogueOutputOp0::Params epilogue0;               ///< epilogue parameters for GEMM0
  typename EpilogueOutputOp1::Params epilogue1;               ///< epilogue parameters for GEMM1
  int batch_count;                                            ///< number of batches (1 for a single GEMM)

  //
  // Methods
  //

  /// Default ctor
  ///
  /// NOTE(review): the previous version wrote `mode(mode)`, initializing the
  /// member from its own indeterminate value (undefined behavior), and left
  /// the batch strides uninitialized. Initialize mode to kGemm — presumably
  /// the intended default, matching GemmUniversal-style argument structs —
  /// and zero all strides.
  CUTLASS_HOST_DEVICE
  Arguments():
    mode(GemmUniversalMode::kGemm),
    problem_size_0(0, 0, 0),
    problem_size_1(0, 0, 0),
    batch_stride_A0(0),
    batch_stride_B0(0),
    batch_stride_B1(0),
    batch_stride_C1(0),
    batch_stride_D1(0),
    batch_stride_Bias0(0),
    batch_stride_Scale0(0),
    batch_count(1) {

  }

  /// Constructs a fully-specified Arguments structure.
  /// Epilogue parameters default-construct and batch_count defaults to 1
  /// (non-batched operation).
  CUTLASS_HOST_DEVICE
  Arguments(
    GemmUniversalMode mode_,
    GemmCoord problem_size_0_,
    GemmCoord problem_size_1_,
    TensorRef<ElementA const, LayoutA> ref_A0_,
    TensorRef<ElementB const, LayoutB> ref_B0_,
    TensorRef<ElementC const, LayoutC> ref_C0_,
    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
    TensorRef<ElementB const, LayoutB> ref_B1_,
    TensorRef<ElementC const, LayoutC> ref_C1_,
    TensorRef<ElementC, LayoutC> ref_D1_,
    int64_t batch_stride_A0_,
    int64_t batch_stride_B0_,
    int64_t batch_stride_B1_,
    int64_t batch_stride_C1_,
    int64_t batch_stride_D1_,
    int64_t batch_stride_Bias0_,
    int64_t batch_stride_Scale0_,
    typename EpilogueOutputOp0::Params epilogue0_ =
      typename EpilogueOutputOp0::Params(),
    typename EpilogueOutputOp1::Params epilogue1_ =
      typename EpilogueOutputOp1::Params(),
    int batch_count_ = 1
  ):
    mode(mode_),
    problem_size_0(problem_size_0_),
    problem_size_1(problem_size_1_),
    ref_A0(ref_A0_),
    ref_B0(ref_B0_),
    ref_C0(ref_C0_),
    ref_Scale0(ref_Scale0_),
    ref_Bias0(ref_Bias0_),
    ref_B1(ref_B1_),
    ref_C1(ref_C1_),
    ref_D1(ref_D1_),
    batch_stride_A0(batch_stride_A0_),
    batch_stride_B0(batch_stride_B0_),
    batch_stride_B1(batch_stride_B1_),
    batch_stride_C1(batch_stride_C1_),
    batch_stride_D1(batch_stride_D1_),
    batch_stride_Bias0(batch_stride_Bias0_),
    batch_stride_Scale0(batch_stride_Scale0_),
    epilogue0(epilogue0_),
    epilogue1(epilogue1_),
    batch_count(batch_count_) {

  }
};
using Arguments = typename B2bGemmKernel::Arguments;

private:

Expand Down
Loading

0 comments on commit 87349d3

Please sign in to comment.