Extend DualGemm: support batched mode + decouple B0/B1 layouts (NVIDIA#790)

* Fix MHA kernel

Summary:

ATT

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Extend DualGemm to support batched mode (NVIDIA#5)

Following the GemmUniversalMode::kBatched implementation, batched mode is added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are not compatible: Status::kErrorInvalidProblem is returned if both are set.
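
Not part of this diff — a minimal batched-launch sketch, assuming DualGemmType is an instantiation of the device-level DualGemm below, the tensor_* objects are HostTensor-style device allocations, and the alpha/beta/batch/stride values are placeholders:

    // DualGemmMode::kBatched runs `batch` independent dual GEMMs in one launch.
    // kSplitKSerial must be false and split_k_slices must stay 1, otherwise
    // can_implement() returns Status::kErrorInvalidProblem.
    DualGemmType::Arguments args(
        DualGemmMode::kBatched,              // from dual_gemm_common.h (namespace qualification omitted)
        {m, n, k},                           // per-batch problem size
        tensor_a.device_ref(),
        tensor_b0.device_ref(),
        tensor_c0.device_ref(),
        tensor_d0.device_ref(),
        tensor_b1.device_ref(),
        tensor_c1.device_ref(),
        tensor_d1.device_ref(),
        tensor_d2.device_ref(),
        {alpha0, beta0},                     // epilogue0 params
        {alpha1, beta1},                     // epilogue1 params
        {},                                  // epilogue2 params
        /*split_k_slices=*/1,
        /*batch_count=*/batch,
        /*batch_stride_A=*/int64_t(m) * k,
        /*batch_stride_B0=*/int64_t(k) * n,
        /*batch_stride_B1=*/int64_t(k) * n,
        /*batch_stride_C=*/int64_t(m) * n,
        /*batch_stride_D=*/int64_t(m) * n);

    DualGemmType dual_gemm;
    cutlass::Status status = dual_gemm.can_implement(args);
    if (status == cutlass::Status::kSuccess) {
      // No split-K workspace is needed when split_k_slices == 1.
      status = dual_gemm.initialize(args, /*workspace=*/nullptr);
    }
    if (status == cutlass::Status::kSuccess) {
      status = dual_gemm();                  // launches the batched dual GEMM
    }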

* Decouple LayoutB0 and LayoutB1 in DualGemm

The DualGemm template assumed a single layout, LayoutB, for both right-hand operand matrices B0 and B1. This is problematic when the two matrices have different layouts. In particular, one of them may be row-major while the other is a (column) vector that has to be broadcast in column-major with zero stride (e.g., as {B1.device_data(), 0}) so that the DualGemm implementation can process B0 and B1 simultaneously.

In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). The batch strides of B0 and B1 are likewise decoupled to accommodate the column-vector B1 case described above.
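
A sketch of the motivating broadcast case; only the {B1.device_data(), 0} construction comes from the description above, while the element type and names are assumptions:

    // B0 stays a regular k x n row-major operand.
    using LayoutB0 = cutlass::layout::RowMajor;
    // B1 is a single k-element column broadcast across all n columns:
    // a column-major layout with stride 0 makes every column alias the same data.
    using LayoutB1 = cutlass::layout::ColumnMajor;

    cutlass::TensorRef<cutlass::half_t const, LayoutB1> ref_b1{
        B1.device_data(), 0};   // equivalently passed inline as {B1.device_data(), 0}

    // ref_b1 is then passed as the ref_B1_ argument of DualGemm::Arguments; with
    // the batch strides now decoupled, batch_stride_B1 can be set to 0 so every
    // batch reads the same broadcast vector while B0 advances by batch_stride_B0.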

* Remove comment as no longer relevant

* Revert "Fix MHA kernel"

---------

Co-authored-by: mikeiovine <mikeiovine@fb.com>
aakhundov and mikeiovine authored Feb 13, 2023
1 parent ce8597d commit 3c995c7
Showing 7 changed files with 793 additions and 305 deletions.
110 changes: 76 additions & 34 deletions examples/45_dual_gemm/device/dual_gemm.h
@@ -52,6 +52,7 @@ D2 = element_wise(D0, D1)
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"

#include "../kernel/dual_gemm.h"
#include "../dual_gemm_common.h"

////////////////////////////////////////////////////////////////////////////////

@@ -68,8 +69,10 @@ template <
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Layout type for B0 matrix operand
typename LayoutB0_,
/// Layout type for B1 matrix operand
typename LayoutB1_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
@@ -119,8 +122,10 @@ class DualGemm {
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using LayoutB0 = LayoutB0_;
using LayoutB1 = LayoutB1_;
using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
@@ -151,23 +156,31 @@ class DualGemm {
/// Define the threadblock-scoped matrix multiply-accumulate
static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
static_assert(kStages >= 3, "Only multistage is implemented");
using Mma = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape,
InstructionShape, Stages, Operator>::ThreadblockMma;
using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape,
InstructionShape, Stages, Operator>::ThreadblockMma;
using DualMma = threadblock::DualMmaMultistage<
typename Mma::Shape,
typename Mma::IteratorA,
typename Mma::SmemIteratorA,
Mma::kCacheOpA,
typename Mma::IteratorB,
typename Mma::SmemIteratorB,
Mma::kCacheOpB,
typename Mma::ElementC,
typename Mma::LayoutC,
typename Mma::Policy,
Mma::kStages,
typename Mma0::Shape,
typename Mma0::IteratorA,
typename Mma0::SmemIteratorA,
Mma0::kCacheOpA,
typename Mma0::IteratorB,
typename Mma0::SmemIteratorB,
Mma0::kCacheOpB,
typename Mma1::IteratorB,
typename Mma1::SmemIteratorB,
typename Mma0::ElementC,
typename Mma0::LayoutC,
typename Mma0::Policy,
typename Mma1::Policy,
Mma0::kStages,
SharedMemoryClearOption::kNone
>;

@@ -176,11 +189,11 @@ class DualGemm {
/// Define the epilogue
using Epilogue0 =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0,
ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
EpilogueOutputOp0::kCount>::Epilogue;
using Epilogue1 =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1,
ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
EpilogueOutputOp1::kCount>::Epilogue;

/// Define the kernel-level GEMM operator.
@@ -197,12 +210,13 @@ class DualGemm {
// Data members
//

DualGemmMode mode;
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementB const, LayoutB0> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementC, LayoutC> ref_D0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementB const, LayoutB1> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
TensorRef<ElementC, LayoutC> ref_D2;
@@ -211,6 +225,13 @@ class DualGemm {
typename EpilogueOutputOp2::Params epilogue2;
int split_k_slices;

int batch_count;
int64_t batch_stride_A;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C;
int64_t batch_stride_D;

//
// Methods
//
@@ -224,12 +245,13 @@ class DualGemm {
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
DualGemmMode mode,
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementB const, LayoutB0> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementC, LayoutC> ref_D0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementB const, LayoutB1> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
TensorRef<ElementC, LayoutC> ref_D2_,
@@ -239,8 +261,15 @@ class DualGemm {
typename EpilogueOutputOp1::Params(),
typename EpilogueOutputOp2::Params epilogue2_ =
typename EpilogueOutputOp2::Params(),
int split_k_slices_ = 1
int split_k_slices_ = 1,
int batch_count = 1,
int64_t batch_stride_A = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C = 0,
int64_t batch_stride_D = 0
):
mode(mode),
problem_size(problem_size_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
@@ -253,7 +282,13 @@ class DualGemm {
epilogue0(epilogue0_),
epilogue1(epilogue1_),
epilogue2(epilogue2_),
split_k_slices(split_k_slices_) {
split_k_slices(split_k_slices_),
batch_count(batch_count),
batch_stride_A(batch_stride_A),
batch_stride_B0(batch_stride_B0),
batch_stride_B1(batch_stride_B1),
batch_stride_C(batch_stride_C),
batch_stride_D(batch_stride_D) {

}
};
@@ -271,6 +306,9 @@ class DualGemm {
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {

if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
return Status::kErrorInvalidProblem;
}
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
@@ -304,17 +342,15 @@ class DualGemm {
static size_t get_workspace_size(Arguments const &args) {

size_t bytes = 0;

// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;

cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);

if (kSplitKSerial && args.split_k_slices > 1) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;

cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);

bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
@@ -331,7 +367,7 @@ class DualGemm {
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);

if (kSplitKSerial) {
if (args.split_k_slices > 1) {
@@ -357,6 +393,7 @@ class DualGemm {

// Initialize the Params structure
params_ = typename DualGemmKernel::Params{
args.mode,
args.problem_size,
grid_shape,
args.ref_A0.non_const_ref(),
@@ -371,6 +408,11 @@ class DualGemm {
args.epilogue1,
args.epilogue2,
reinterpret_cast<int *>(workspace),
args.batch_stride_A,
args.batch_stride_B0,
args.batch_stride_B1,
args.batch_stride_C,
args.batch_stride_D,
};

return Status::kSuccess;