upstream internal updates (NVIDIA#616)
Co-authored-by: yuzhai <yuzhai@nvidia.com>
yzhaiustc and yuzhai authored Sep 5, 2022
1 parent b72cbf9 commit b1d3f9b
Showing 2 changed files with 26 additions and 22 deletions.
11 changes: 8 additions & 3 deletions examples/35_gemm_softmax/gemm_softmax.cu
@@ -218,6 +218,12 @@ struct Testbed {
   using OperatorClass = cutlass::arch::OpClassTensorOp;
   using ArchTag = cutlass::arch::Sm80;
 
+  // ApplyShape strongly impacts the final softmax performance.
+  // Set ApplyShape::kColumn to the next multiple of 32 after
+  // (gemm_N / alignment).
+  // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
+  using ApplyShape = cutlass::MatrixShape<1, 1024>;
+
   static int const kStages = 3;
 
   /// Linear scaling operator
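
A quick illustration of that tuning heuristic (a sketch only; kGemmN, kAlignmentElems, round_up_to_32, and TunedApplyShape are hypothetical names with assumed values, not part of the example):

#include <cutlass/matrix_shape.h>

// Hypothetical helper: round up to the next multiple of 32.
constexpr int round_up_to_32(int x) { return ((x + 31) / 32) * 32; }

constexpr int kGemmN          = 1024;  // assumed GEMM N extent
constexpr int kAlignmentElems = 8;     // assumed 8-element (128-bit) vector access

constexpr int kColumn = round_up_to_32(kGemmN / kAlignmentElems);   // 128
constexpr int kRow    = (128 / kColumn > 1) ? (128 / kColumn) : 1;  // max(1, 128 / 128) = 1

using TunedApplyShape = cutlass::MatrixShape<kRow, kColumn>;        // MatrixShape<1, 128>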
@@ -239,7 +245,8 @@ struct Testbed {
     WarpShape,
     InstructionShape,
     EpilogueFunctorOp,
-    kStages
+    kStages,
+    ApplyShape
   >;
 
   using ElementNorm = typename GemmSoftmax::ElementNorm;
@@ -710,6 +717,4 @@ int main(int argc, const char **argv) {
   return (disposition == Disposition::kPassed ? 0 : -1);
 }
 
-
 /////////////////////////////////////////////////////////////////////////////////////////////////
-

37 changes: 18 additions & 19 deletions examples/35_gemm_softmax/gemm_with_softmax.h
@@ -79,7 +79,7 @@ template <
   typename ElementSoft_,
   typename ElementSoftmaxCompute_,
   int Alignment,
-  typename Shape_ = MatrixShape<4, 16>
+  typename ApplyShape_ = MatrixShape<1, 1024>
 >
 class ApplySoftmax {
 public:
@@ -91,7 +91,7 @@ class ApplySoftmax {
   using ElementSoftmaxCompute = ElementSoftmaxCompute_;
 
   static int const kAlignment = Alignment;
-  using Shape = Shape_;
+  using ApplyShape = ApplyShape_;
 
   using Layout = cutlass::layout::RowMajor;
 
@@ -202,7 +202,7 @@ class ApplySoftmax {
     using AccessTypeD = AlignedArray<ElementD, kAlignment>;
 
     int block_batch = blockIdx.z;
-    int block_m = blockIdx.x * Shape::kRow;
+    int block_m = blockIdx.x * ApplyShape::kRow;
     int block_n = 0;
 
     int thread_m = threadIdx.y;
@@ -256,8 +256,8 @@ class ApplySoftmax {
       params.args.batch_stride_Soft * block_batch +
       params.args.ref_Soft.layout()({idx_m, idx_n}));
 
-    ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
-    ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];
+    ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
+    ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
 
     //
     // Loop
@@ -266,10 +266,9 @@ class ApplySoftmax {
     for (
       int idx = 0;
       idx < params.args.extent.column();
-      idx += Shape::kColumn * kAlignment) {
+      idx += ApplyShape::kColumn * kAlignment) {
 
       if (idx_n < params.args.extent.column()) {
-
         AccessTypeD fetch;
         arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
 
@@ -279,14 +278,13 @@ class ApplySoftmax {
         arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
       }
 
-      access_d += Shape::kColumn;
-      access_soft += Shape::kColumn;
-      idx_n += Shape::kColumn * kAlignment;
+      access_d += ApplyShape::kColumn;
+      access_soft += ApplyShape::kColumn;
+      idx_n += ApplyShape::kColumn * kAlignment;
     }
   }
 };
 
-
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace kernel
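
For context, ApplySoftmax is the second pass of a two-pass softmax: a reduction kernel first produces a per-row norm (the row maximum in this example) and inv_sum (the reciprocal of the row's sum of exponentials), and this kernel then rescales each element. A scalar sketch of that per-row arithmetic (a hypothetical standalone function, not the CUTLASS kernel itself):

#include <cmath>

// Computes soft[j] = exp(d[j] - norm) * inv_sum across one row.
void apply_softmax_row(float const* d, float norm, float inv_sum, float* soft, int n) {
  for (int j = 0; j < n; ++j) {
    soft[j] = std::exp(d[j] - norm) * inv_sum;
  }
}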
@@ -308,6 +306,7 @@ template <
   typename InstructionShape_,
   typename EpilogueFunctorOp_,
   int kStages_,
+  typename ApplyShape_ = MatrixShape<1, 1024>,
   int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
   int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
   int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
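
The alignment defaults above pick the widest 128-bit access each element type allows; for 16-bit operands that is 128 / 16 = 8 elements per access. A compile-time check of that arithmetic (a sketch, assuming half_t operands):

#include <cutlass/numeric_types.h>

// A 128-bit vector access holds eight 16-bit elements.
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
              "half_t operands default to 8-element (128-bit) alignment");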
@@ -338,6 +337,8 @@ class GemmSoftmax {
   using EpilogueFunctorOp = EpilogueFunctorOp_;
   using ElementNorm = ElementNorm_;
 
+  using ApplyShape = ApplyShape_;
+
   // These are mandatory layouts.
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutN = cutlass::layout::RowMajor;
@@ -427,9 +428,7 @@ class GemmSoftmax {
     ElementSoft,
     ElementSoftmaxCompute,
     AlignmentSoftmax,
-    MatrixShape<
-      1, 1024
-    >
+    ApplyShape
   >;
 
   using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
@@ -616,14 +615,14 @@ class GemmSoftmax {
     // Launch the SoftmaxApplyKernel
     //
 
-    dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow);
+    dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
 
-    int cta_rows = SoftmaxApplyKernel::Shape::kRow;
-    int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment;
+    int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
+    int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
 
     dim3 apply_grid(
-      (params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows,
-      (params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns,
+      (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
+      (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
       params_.softmax.args.batch_count);
 
     Kernel<SoftmaxApplyKernel><<<
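
To make the renamed launch bookkeeping concrete, here is the same arithmetic with assumed values inlined (softmax extent 4096 x 8192, ApplyShape = MatrixShape<1, 1024>, kAlignment = 8, batch_count = 1); each threadblock runs 1024 threads covering 1 row by 1024 * 8 = 8192 columns, so one block in the grid's y dimension spans a full row:

dim3 apply_block(1024, 1);           // ApplyShape::kColumn x ApplyShape::kRow threads

int threadblock_rows    = 1;         // ApplyShape::kRow
int threadblock_columns = 1024 * 8;  // ApplyShape::kColumn * kAlignment = 8192

dim3 apply_grid(
  (4096 + 1 - 1) / 1,                // 4096 row tiles in x
  (8192 + 8192 - 1) / 8192,          // 1 column tile in y
  1);                                // batch_count in z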
