Softmax (NVIDIA#546)
* add test layernorm g-mem version

* Delete include/configure directory

* Delete examples/test_layernorm directory

* Update gemm_with_softmax.h

* Update gemm_softmax.cu

* Update linear_combination.h

* Update fast_math.h

* remove redundant vars

Co-authored-by: yujia.zhai <yujia.zhai@bytedance.com>
Co-authored-by: yuzhai <yuzhai@nvidia.com>
3 people authored Jul 2, 2022
1 parent e45e773 commit 04a9777
Showing 4 changed files with 407 additions and 355 deletions.
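For context: instead of keeping a single per-row vector, the revised example stores per-threadblock partial results, tensor_N for the running row maxima and tensor_S for the partial exponential sums, laid out as {block_num, M} and combined by a final reduction into the softmax normalizer. Below is a minimal host-side sketch of that combine step, assuming illustrative names (reduce_row_partials, partial_max, partial_sum) rather than the actual CUTLASS kernels:

// Minimal host-side sketch of combining per-threadblock softmax partials.
// Assumes each block b produced partial_max[b] (the max over its tile of the
// row) and partial_sum[b] = sum(exp(x - partial_max[b])) over the same tile.
#include <algorithm>
#include <cmath>
#include <vector>

void reduce_row_partials(const std::vector<float>& partial_max,
                         const std::vector<float>& partial_sum,
                         float& row_max, float& row_sum) {
  row_max = *std::max_element(partial_max.begin(), partial_max.end());
  row_sum = 0.0f;
  for (size_t b = 0; b < partial_max.size(); ++b) {
    // Rescale each block's sum to the global row maximum before accumulating.
    row_sum += partial_sum[b] * std::exp(partial_max[b] - row_max);
  }
  // The softmax of element x_j in the row is then exp(x_j - row_max) / row_sum.
}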
62 changes: 39 additions & 23 deletions examples/35_gemm_softmax/gemm_softmax.cu
@@ -55,6 +55,7 @@
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/epilogue/thread/linear_combination.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

#include "gemm_with_softmax.h"
@@ -204,14 +205,24 @@ struct Testbed {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;

+/// Linear scaling operator
+using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
+  ElementC,
+  128 / cutlass::sizeof_bits<ElementC>::value,
+  ElementCompute,
+  ElementCompute
+>;

using GemmSoftmax = cutlass::GemmSoftmax<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC,
-ElementCompute
+ElementCompute,
+EpilogueFunctorOp
>;

-using ElementN = typename GemmSoftmax::ElementN;
+using ElementNorm = typename GemmSoftmax::ElementNorm;
+using ElementSum = typename GemmSoftmax::ElementSum;
using LayoutC = typename GemmSoftmax::LayoutC;

//
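The LinearCombination functor selected above is CUTLASS's standard linear-scaling epilogue, D = alpha * accumulator + beta * C; its second template argument (128 / sizeof_bits of ElementC) is the per-access vector width. In scalar form the computation amounts to the following sketch, not the vectorized functor itself:

// Scalar sketch of what the LinearCombination epilogue computes per element.
// alpha and beta are the ElementCompute scalars supplied in the GEMM arguments.
ElementC linear_combination(ElementCompute alpha, ElementCompute accum,
                            ElementCompute beta,  ElementCompute source) {
  return ElementC(alpha * accum + beta * source);
}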
@@ -224,13 +235,16 @@ struct Testbed {
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<ElementD, LayoutC> tensor_D;
-cutlass::HostTensor<ElementN, LayoutC> tensor_N;
+cutlass::HostTensor<ElementNorm, LayoutC> tensor_N;
+cutlass::HostTensor<ElementSum, LayoutC> tensor_S;
cutlass::HostTensor<ElementSoftmax, LayoutC> tensor_Softmax;

cutlass::HostTensor<ElementD, LayoutC> reference_D;
-cutlass::HostTensor<ElementN, LayoutC> reference_N;
+cutlass::HostTensor<ElementNorm, LayoutC> reference_N;
cutlass::HostTensor<ElementSoftmax, LayoutC> reference_Softmax;

+int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;

//
// Methods
//
@@ -247,7 +261,8 @@ struct Testbed {
tensor_C.reset({options.problem_size.m(), options.problem_size.n()});
tensor_D.reset({options.problem_size.m(), options.problem_size.n()});

-tensor_N.reset({options.problem_size.m(), 1});
+tensor_N.reset({block_num, options.problem_size.m()});
+tensor_S.reset({block_num, options.problem_size.m()});
tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()});

reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false);
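A quick worked example of the shapes above, using assumed sizes rather than values from the commit: with problem_size.n() = 1024 and ThreadblockShape::kN = 128, the ceil-division gives block_num = 8, so tensor_N and tensor_S each hold eight partial values per row:

// Illustrative numbers only; the real values come from the Options struct.
int n  = 1024;                          // options.problem_size.n()
int kN = 128;                           // GemmSoftmax::ThreadblockShape::kN
int block_num = (n + kN - 1) / kN;      // ceil(1024 / 128) = 8
// tensor_N and tensor_S are then reset to {block_num, M}: one partial
// maximum and one partial sum per threadblock column, for every row.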
@@ -342,7 +357,7 @@ struct Testbed {

cutlass::reference::host::TensorFill(
reference_N.host_view(),
-ElementN()
+ElementNorm()
);

cutlass::reference::host::TensorFill(
@@ -354,6 +369,7 @@ struct Testbed {
tensor_B.sync_device();
tensor_D.sync_device();
tensor_N.sync_device();
+tensor_S.sync_device();
tensor_Softmax.sync_device();
}

@@ -377,6 +393,7 @@ struct Testbed {
ElementCompute(options.beta)
},
tensor_N.device_ref(),
+tensor_S.device_ref(),
tensor_Softmax.device_ref()
);

@@ -420,7 +437,7 @@ struct Testbed {
for (int m = 0; m < options.problem_size.m(); ++m) {
reference_N.at({m, 0}) = reference_D.at({m, 0});
for (int n = 1; n < options.problem_size.n(); ++n) {
-reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementN(reference_D.at({m, n})));
+reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n})));
}
}

@@ -454,6 +471,20 @@ struct Testbed {
std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n";
}

+bool verify_tensor_N(cutlass::HostTensor<ElementNorm, LayoutC> tensor_N, \
+                     cutlass::HostTensor<ElementNorm, LayoutC> reference_N) {
+
+  for (int m = 0; m < options.problem_size.m(); ++m) {
+    float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0}));
+    if (fabs(diff) > options.tolerance) {
+      return false;
+    }
+  }
+
+  return true;
+}

/// Verifies the reference matches
bool verify() {

@@ -489,22 +520,7 @@ struct Testbed {
}

if (!verified_N) {

-double norm_diff = cutlass::reference::host::TensorNormDiff(
-  tensor_N.host_view(),
-  reference_N.host_view());
-
-double norm_reference = cutlass::reference::host::TensorNorm(
-  reference_N.host_view());
-
-double rel_error = norm_diff / norm_reference;
-
-if (rel_error > kThreshold) {
-  std::cerr << "\n\nTensor N Relative error: " << rel_error << std::endl;
-}
-else {
-  verified_N = true;
-}
+verified_N = verify_tensor_N(tensor_N, reference_N);
}

if (!verified_Softmax) {
(Diffs for the remaining changed files, gemm_with_softmax.h, linear_combination.h, and fast_math.h, are not shown here.)
