diff --git a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp
index accc6d9d8b..a036df2f03 100644
--- a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp
+++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp
@@ -69,7 +69,8 @@ template <
   class StrideC_,
   class StrideD_,
   class ThreadEpilogueOp_,
-  class EpilogueSchedule_
+  class EpilogueSchedule_,
+  bool PerColumnBias_ = false
 >
 class EpilogueTensorBroadcast {
 public:
@@ -101,6 +102,9 @@ class EpilogueTensorBroadcast {
   static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled;
   static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled;
 
+  static constexpr bool PerColumnBias = PerColumnBias_;
+  using BiasStride = typename cute::conditional_t<PerColumnBias, Stride<_0, _1, _0>, Stride<_1, _0, _0>>;
+
   struct SharedStorage { };
 
   // Host side epilogue arguments
@@ -194,7 +198,7 @@ class EpilogueTensorBroadcast {
 
     auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
     auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
-    auto stride_bias = detail::get_epilogue_stride<EpilogueSchedule>(Stride<_1, _0, _0>{});
+    auto stride_bias = detail::get_epilogue_stride<EpilogueSchedule>(BiasStride{});
 
     // Represent the full output tensor
     Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c);                 // (m,n,l)
diff --git a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp
index ec7fbea792..0cabaa72f4 100644
--- a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp
+++ b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp
@@ -76,6 +76,8 @@ struct Testbed3xTensorBroadcast {
   static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled;
   static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled;
 
+  static constexpr bool PerColBias = Epilogue::PerColumnBias;
+
   using LayoutTagA = typename TestBedImpl::LayoutTagA;
   using LayoutTagB = typename TestBedImpl::LayoutTagB;
   using LayoutTagC = typename TestBedImpl::LayoutTagC;
@@ -130,8 +132,8 @@ struct Testbed3xTensorBroadcast {
 
   void initialize_bias(ProblemShapeType problem_size) {
     auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
-    auto M = cute::get<0>(problem_shape_MNKL);
-    bias.resize(cutlass::Coord<1>(M));
+    auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL);
+    bias.resize(cutlass::Coord<1>(bias_size));
 
     EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023));
     bias.sync_device();
@@ -186,7 +188,8 @@ struct Testbed3xTensorBroadcast {
     std::ofstream file(fname.str());
     file
       << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
-      << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias << "\n\n";
+      << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias
+      << ", per-col bias: " << PerColBias << "\n\n";
 
     if (use_bias){
       file << "Bias = \n" << bias.host_view()<< "\n\n";
@@ -225,7 +228,7 @@ struct Testbed3xTensorBroadcast {
     auto D = cute::make_tensor(impl_.reference_D.host_data(),
         cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
     auto Bias = cute::make_tensor(static_cast<ElementBias*>(use_bias ? bias.host_data() : nullptr),
-        cute::make_layout(cute::make_shape(M, 1)));
+        cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1)));
     auto C0 = cute::make_tensor(impl_.tensor_C.host_data(),
         cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c));
     auto C1 = cute::make_tensor(tensor_C1.host_data(),
@@ -263,7 +266,9 @@ struct Testbed3xTensorBroadcast {
         decltype(dummy_Aux),
         decltype(dummy_Valpha),
         decltype(dummy_Vbeta),
-        ActivationFunctor> epilogue_params{
+        ActivationFunctor,
+        cutlass::plus<ElementCompute>,
+        PerColBias> epilogue_params{
         alpha,
         dummy_beta,
         dummy_C,
diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu
index 735d14fb90..99370aa0c8 100644
--- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu
+++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu
@@ -97,6 +97,54 @@ TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128
   EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast<Gemm>());
 }
 
+TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128x32_1x2x1_ActReLU_Bin0Mul_Bin1Plus_UnaryHardSwish_PerColBias) {
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  using ElementOutput = float;
+  using ElementAccumulator = ElementOutput;
+  using ElementCompute = ElementOutput;
+  using ElementBias = ElementOutput;
+
+  using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      float, LayoutA, 4,
+      float, LayoutB, 4,
+      float,
+      Shape<_64,_128,_128>, Shape<_1,_2,_1>,
+      cutlass::gemm::collective::StageCountAuto,
+      cutlass::gemm::collective::KernelScheduleAuto
+    >::CollectiveOp;
+
+  using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
+    cutlass::epilogue::collective::EpilogueTensorBroadcast<
+        cutlass::gemm::TagToStrideC_t<LayoutC>,
+        cutlass::gemm::TagToStrideC_t<LayoutC>,
+        cutlass::epilogue::thread::LinearCombinationTensorBroadcast<
+            ElementOutput, ElementAccumulator, ElementCompute, ElementBias,
+            cutlass::epilogue::thread::ReLu,
+            cutlass::multiplies,
+            cutlass::plus,
+            cutlass::epilogue::thread::HardSwish
+        >,
+        cutlass::gemm::EpilogueDefault,
+        /* PerColBias = */ true>>;
+
+  EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled);
+  EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled);
+  EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled);
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveOp,
+      EpilogueOp
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast<Gemm>());
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 
 #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp
index 70c8c0da2b..0f7b19f9b8 100644
--- a/tools/util/include/cutlass/util/reference/host/gett.hpp
+++ b/tools/util/include/cutlass/util/reference/host/gett.hpp
@@ -93,7 +93,8 @@ template<
   class VectorAlpha_ = TensorD_,                      // (M, 1)
   class VectorBeta_ = VectorAlpha_,                   // (M, 1)
   class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
-  class BiasBinaryOp_ = cutlass::plus<ElementCompute_>
+  class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
+  bool PerColumnBias_ = false
 >
 struct GettEpilogueParams {
   using ElementScalar = ElementScalar_;
@@ -114,6 +115,8 @@ struct GettEpilogueParams {
   using EngineD = typename TensorD::engine_type;
   using LayoutD = typename TensorD::layout_type;
 
+  static constexpr bool PerColumnBias = PerColumnBias_;
+
   ElementScalar alpha = ElementScalar(1);
   ElementScalar beta = ElementScalar(0);
 
@@ -256,6 +259,8 @@ void gett_epilogue(
   using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
   using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
 
+  constexpr bool PerColBias = EpilogueParams::PerColumnBias;
+
   constexpr bool IsScalingAndAmaxOutputNeeded =
       cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
       cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
@@ -334,7 +339,7 @@ void gett_epilogue(
       ElementCompute output = mul(converted_alpha, converted_acc);
 
       if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
-        ElementCompute converted_bias = bias_converter(epilogue_params.Bias(m + m_b));
+        ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
         output = bias_op(output, converted_bias);
       }
 
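
For reviewers, a minimal, self-contained sketch of the broadcast semantics the new PerColumnBias switch selects (not part of the patch; apply_bias and its signature are illustrative only, they are not CUTLASS APIs). With PerColumnBias = false the bias vector has one entry per output row (length M, stride (1,0,0)); with PerColumnBias = true it has one entry per output column (length N, stride (0,1,0)), matching the indexing change made in gett_epilogue above.

// Illustrative host-side sketch of per-row vs. per-column bias broadcast.
// apply_bias is a hypothetical helper, shown only to clarify the semantics.
#include <cstddef>
#include <vector>

template <bool PerColumnBias>
void apply_bias(std::vector<float>& D,           // M x N output, row-major
                std::vector<float> const& bias,  // length M (per-row) or N (per-column)
                std::size_t M, std::size_t N) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      // Mirrors the reference epilogue: Bias(PerColBias ? n + n_b : m + m_b)
      float const b = PerColumnBias ? bias[n] : bias[m];
      D[m * N + n] += b;
    }
  }
}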