Skip to content

Commit

Permalink
Allow per-column bias in EpilogueTensorBroadcast (NVIDIA#1275)
Browse files Browse the repository at this point in the history
* Allow per-column bias in EpilogueTensorBroadcast

EpilogueTensorBroadcast only supports per-row vector broadcast, because
the bias stride is hardcoded.

It can easily support both if the bias stride is made conditional, and
the original behavior is maintained by defaulting to per-row.

* Add unit test for EpilogueTensorBroadcast with per-col bias

---------

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Ali Hassani <ali@hippoml.com>
  • Loading branch information
3 people authored Jan 4, 2024
1 parent c9591a6 commit d4be5ab
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ template <
class StrideC_,
class StrideD_,
class ThreadEpilogueOp_,
class EpilogueSchedule_
class EpilogueSchedule_,
bool PerColumnBias_ = false
>
class EpilogueTensorBroadcast {
public:
Expand Down Expand Up @@ -101,6 +102,9 @@ class EpilogueTensorBroadcast {
static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled;
static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled;

static constexpr bool PerColumnBias = PerColumnBias_;
using BiasStride = typename cute::conditional_t<PerColumnBias, Stride<_0, _1, _0>, Stride<_1, _0, _0>>;

struct SharedStorage { };

// Host side epilogue arguments
Expand Down Expand Up @@ -194,7 +198,7 @@ class EpilogueTensorBroadcast {

auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
auto stride_bias = detail::get_epilogue_stride<EpilogueSchedule>(Stride<_1, _0, _0>{});
auto stride_bias = detail::get_epilogue_stride<EpilogueSchedule>(BiasStride{});

// Represent the full output tensor
Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c); // (m,n,l)
Expand Down
15 changes: 10 additions & 5 deletions test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct Testbed3xTensorBroadcast {
static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled;
static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled;

static constexpr bool PerColBias = Epilogue::PerColumnBias;

using LayoutTagA = typename TestBedImpl::LayoutTagA;
using LayoutTagB = typename TestBedImpl::LayoutTagB;
using LayoutTagC = typename TestBedImpl::LayoutTagC;
Expand Down Expand Up @@ -130,8 +132,8 @@ struct Testbed3xTensorBroadcast {

void initialize_bias(ProblemShapeType problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto M = cute::get<0>(problem_shape_MNKL);
bias.resize(cutlass::Coord<1>(M));
auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL);
bias.resize(cutlass::Coord<1>(bias_size));

EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023));
bias.sync_device();
Expand Down Expand Up @@ -186,7 +188,8 @@ struct Testbed3xTensorBroadcast {
std::ofstream file(fname.str());
file
<< "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias << "\n\n";
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias
<< ", per-col bias: " << PerColBias << "\n\n";

if (use_bias){
file << "Bias = \n" << bias.host_view()<< "\n\n";
Expand Down Expand Up @@ -225,7 +228,7 @@ struct Testbed3xTensorBroadcast {
auto D = cute::make_tensor(impl_.reference_D.host_data(),
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
auto Bias = cute::make_tensor(static_cast<ElementBias*>(use_bias ? bias.host_data() : nullptr),
cute::make_layout(cute::make_shape(M, 1)));
cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1)));
auto C0 = cute::make_tensor(impl_.tensor_C.host_data(),
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c));
auto C1 = cute::make_tensor(tensor_C1.host_data(),
Expand Down Expand Up @@ -263,7 +266,9 @@ struct Testbed3xTensorBroadcast {
decltype(dummy_Aux),
decltype(dummy_Valpha),
decltype(dummy_Vbeta),
ActivationFunctor> epilogue_params{
ActivationFunctor,
cutlass::plus<ElementCompute>,
PerColBias> epilogue_params{
alpha,
dummy_beta,
dummy_C,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,54 @@ TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128
EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast<Gemm>());
}

// Unit test: EpilogueTensorBroadcast with per-column bias enabled
// (PerColBias = true, the new template parameter added by this commit).
// Builds an SM90 warp-specialized GEMM (f32 RowMajor A x f32 ColumnMajor B,
// ColumnMajor C/D) whose epilogue applies ReLU activation, multiplies as
// binary op 0, plus as binary op 1, and HardSwish as the final unary op,
// then runs the shared TestAllTensorBroadcast harness against it.
// NOTE(review): the test name says 64x128x32 but the tile shape below is
// Shape<_64,_128,_128> (K = 128) — confirm the name or the shape is intended.
TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128x32_1x2x1_ActReLU_Bin0Mul_Bin1Plus_UnaryHardSwish_PerColBias) {
// Operand/output layouts: A is row-major, B and C/D are column-major.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;

// Everything (output, accumulator, epilogue compute, bias) is f32.
using ElementOutput = float;
using ElementAccumulator = ElementOutput;
using ElementCompute = ElementOutput;
using ElementBias = ElementOutput;

// Mainloop: SM90 tensor-op GEMM, 64x128x128 tile, 1x2x1 cluster,
// stage count and kernel schedule chosen automatically by the builder.
using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
float, LayoutA, 4,
float, LayoutB, 4,
float,
Shape<_64,_128,_128>, Shape<_1,_2,_1>,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

// Epilogue: tensor-broadcast epilogue with the per-column bias path
// selected via the trailing boolean template argument.
using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::EpilogueTensorBroadcast<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::epilogue::thread::LinearCombinationTensorBroadcast<
ElementOutput, ElementAccumulator, ElementCompute, ElementBias,
cutlass::epilogue::thread::ReLu,
cutlass::multiplies,
cutlass::plus,
cutlass::epilogue::thread::HardSwish
>,
cutlass::gemm::EpilogueDefault,
/* PerColBias = */ true>>;

// Sanity-check that the thread epilogue actually enables both binary ops
// and the unary op configured above.
EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled);
EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled);
EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled);

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveOp,
EpilogueOp
>;

// Run the full tensor-broadcast test sweep over the device-level adapter.
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast<Gemm>());
}

///////////////////////////////////////////////////////////////////////////////

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
9 changes: 7 additions & 2 deletions tools/util/include/cutlass/util/reference/host/gett.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ template<
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
Expand All @@ -114,6 +115,8 @@ struct GettEpilogueParams {
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;

static constexpr bool PerColumnBias = PerColumnBias_;

ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);

Expand Down Expand Up @@ -256,6 +259,8 @@ void gett_epilogue(
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;

constexpr bool PerColBias = EpilogueParams::PerColumnBias;

constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
Expand Down Expand Up @@ -334,7 +339,7 @@ void gett_epilogue(
ElementCompute output = mul(converted_alpha, converted_acc);

if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(m + m_b));
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
output = bias_op(output, converted_bias);
}

Expand Down

0 comments on commit d4be5ab

Please sign in to comment.