Add conversion from ElementBias to ElementCompute (NVIDIA#961)
jackkosaian authored May 27, 2023
1 parent 6f47420 commit 7dbf423
Showing 4 changed files with 203 additions and 19 deletions.
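Background for the diff below: cutlass::NumericArrayConverter (cutlass/numeric_conversion.h) converts every element of a cutlass::Array from a source type to a destination type in one call. This commit uses it to turn bias fragments of type ElementBias into fragments of type ElementCompute before the thread-level epilogue runs. A minimal, free-standing sketch of that pattern, with illustrative element types and a hypothetical function name rather than anything taken from the epilogue itself:

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// Illustrative stand-ins for ElementBias, ElementCompute, and the fragment width.
using ElementBias    = cutlass::half_t;
using ElementCompute = float;
constexpr int kElements = 8;

// Hypothetical helper showing the conversion pattern the commit introduces.
void convert_bias_fragment() {
  cutlass::Array<ElementBias, kElements> bias;
  bias.fill(ElementBias(1.5f));

  // Converts all kElements entries element-wise in a single call.
  cutlass::NumericArrayConverter<ElementCompute, ElementBias, kElements> bias_converter;
  cutlass::Array<ElementCompute, kElements> converted = bias_converter(bias);
  (void)converted;  // the real epilogue passes this on to epilogue_op
}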
@@ -475,6 +475,12 @@ class CollectiveEpilogue<
Tensor tRS_rT_frg = recast<typename ThreadEpilogueOp::FragmentT>(tRS_rT);
Tensor tRS_rBias_frg = recast<typename ThreadEpilogueOp::FragmentBias>(tRS_rBias);

+// thread::LinearCombinationBiasElementwise expects that the bias passed in is of
+// type ElementCompute. Therefore, conversion from type ElementBias to ElementCompute
+// is needed before calling the thread-level epilogue.
+cutlass::NumericArrayConverter<ElementCompute, ElementBias,
+                               ThreadEpilogueOp::FragmentBias::kElements> bias_converter;
+
// Partition for smem to register copy (tSR_)
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,InternalElementC>{}, tiled_r2s);
ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
@@ -538,13 +544,15 @@ class CollectiveEpilogue<

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tRS_rD_frg); ++i) {
-epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), tRS_rBias_frg(i));
+typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i));
+epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), converted_bias);
}
}
else {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tRS_rD_frg); ++i) {
-epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rBias_frg(i));
+typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i));
+epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), converted_bias);
}
}

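The two hunks above apply the converter once per fragment inside the unrolled loops, so the thread-level op only ever sees ElementCompute. A condensed, free-standing sketch of that shape; Op, apply_bias_epilogue, and the array arguments are hypothetical stand-ins for epilogue_op and the tRS_* tensors:

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// Hypothetical sketch: convert each bias fragment before the thread-level op.
template <typename Op, int kFragments, int kElements>
void apply_bias_epilogue(Op op,
                         cutlass::Array<cutlass::half_t, kElements> const (&bias)[kFragments],
                         cutlass::Array<float, kElements> (&out)[kFragments]) {
  cutlass::NumericArrayConverter<float, cutlass::half_t, kElements> bias_converter;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kFragments; ++i) {
    // ElementBias -> ElementCompute, mirroring converted_bias in the diff.
    cutlass::Array<float, kElements> converted_bias = bias_converter(bias[i]);
    out[i] = op(converted_bias);
  }
}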
@@ -95,7 +95,7 @@ class LinearCombinationBiasElementwise {
using FragmentSource = FragmentC;
using FragmentOutput = FragmentZ;
using ElementBias = ElementVector;
-using FragmentBias = FragmentCompute;
+using FragmentBias = Array<ElementBias, kElementsPerAccess>;
using ActivationFunctor = ElementwiseOp;
static const ScaleType::Kind kScale = ScaleType::Default;

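The effect of the one-line change above: FragmentBias previously aliased FragmentCompute, forcing callers to materialize the bias in ElementCompute before the epilogue; it is now an Array of ElementBias, matching the data actually loaded, with the conversion handled in the collective epilogue. A simplified before/after view with illustrative element types (the real aliases derive from the class template parameters):

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

// Illustrative values; in the class these come from template parameters.
using ElementCompute = float;
using ElementBias    = cutlass::half_t;  // == ElementVector
constexpr int kElementsPerAccess = 8;

using FragmentCompute = cutlass::Array<ElementCompute, kElementsPerAccess>;
// Before: using FragmentBias = FragmentCompute;              // always ElementCompute
// After: bias fragments carry the bias element type itself.
using FragmentBias = cutlass::Array<ElementBias, kElementsPerAccess>;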
@@ -101,7 +101,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -144,7 +144,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_GELU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_GELU) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -188,7 +188,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU_NoStoreT) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_NoStoreT) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -231,7 +231,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_Negate) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_Negate) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -275,7 +275,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -319,7 +319,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -363,7 +363,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU_VoidC) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -407,4 +407,92 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
EXPECT_TRUE(passed);
}

-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) {
+
+using LayoutA = cutlass::layout::RowMajor;
+using LayoutB = cutlass::layout::ColumnMajor;
+using LayoutC = cutlass::layout::RowMajor;
+using TileShape_MNK = Shape<_256,_128,_64>;
+using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+static constexpr bool StoreT = true;
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
+cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::epilogue::collective::EpilogueTileAuto,
+float, float,
+void, LayoutC, 8,
+cutlass::half_t, LayoutC, 8,
+EpilogueSchedule
+>::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+cutlass::half_t, LayoutA, 8,
+cutlass::half_t, LayoutB, 8,
+float,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative
+>::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+Shape<int,int,int,int>,
+CollectiveMainloop,
+CollectiveEpilogue
+>;
+
+using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+EXPECT_TRUE(passed);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) {
+
+using LayoutA = cutlass::layout::RowMajor;
+using LayoutB = cutlass::layout::ColumnMajor;
+using LayoutC = cutlass::layout::RowMajor;
+using TileShape_MNK = Shape<_256,_128,_64>;
+using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+static constexpr bool StoreT = true;
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
+cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::epilogue::collective::EpilogueTileAuto,
+float, float,
+void, LayoutC, 8,
+cutlass::half_t, LayoutC, 8,
+EpilogueSchedule
+>::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+cutlass::half_t, LayoutA, 8,
+cutlass::half_t, LayoutB, 8,
+float,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative
+>::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+Shape<int,int,int,int>,
+CollectiveMainloop,
+CollectiveEpilogue
+>;
+
+using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+EXPECT_TRUE(passed);
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
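The two tests added above differ from the earlier BiasF32 tests only in the trailing template argument of the epilogue schedule, which selects ElementBias (cutlass::half_t and int8_t); the persistent-kernel file below does the same with TmaWarpSpecializedBiasElementwise. Side by side, as used in the tests (includes as in the test file above):

// Extracted from the new tests above: the last template argument is ElementBias.
using ScheduleBiasF16 = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, /*StoreT=*/true,
    cutlass::half_t>;  // bias loaded as half_t, converted to ElementCompute in the epilogue

using ScheduleBiasS8 = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, /*StoreT=*/true,
    int8_t>;           // bias loaded as int8_t, converted to ElementCompute in the epilogue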
@@ -100,7 +100,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -143,7 +143,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_GELU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_GELU) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -187,7 +187,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU_NoStoreT) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU_NoStoreT) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
@@ -230,7 +230,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_Negate) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_Negate) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -274,7 +274,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -318,7 +318,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -362,7 +362,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU_VoidC) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -406,4 +406,92 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
EXPECT_TRUE(passed);
}

-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) {
+
+using LayoutA = cutlass::layout::RowMajor;
+using LayoutB = cutlass::layout::ColumnMajor;
+using LayoutC = cutlass::layout::RowMajor;
+using TileShape_MNK = Shape<_128,_128,_64>;
+using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+static constexpr bool StoreT = true;
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
+cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::epilogue::collective::EpilogueTileAuto,
+float, float,
+void, LayoutC, 8,
+cutlass::half_t, LayoutC, 8,
+EpilogueSchedule
+>::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+cutlass::half_t, LayoutA, 8,
+cutlass::half_t, LayoutB, 8,
+float,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong
+>::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+Shape<int,int,int,int>,
+CollectiveMainloop,
+CollectiveEpilogue
+>;
+
+using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+EXPECT_TRUE(passed);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) {
+
+using LayoutA = cutlass::layout::RowMajor;
+using LayoutB = cutlass::layout::ColumnMajor;
+using LayoutC = cutlass::layout::RowMajor;
+using TileShape_MNK = Shape<_128,_128,_64>;
+using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+static constexpr bool StoreT = true;
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
+cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::epilogue::collective::EpilogueTileAuto,
+float, float,
+void, LayoutC, 8,
+cutlass::half_t, LayoutC, 8,
+EpilogueSchedule
+>::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+cutlass::half_t, LayoutA, 8,
+cutlass::half_t, LayoutB, 8,
+float,
+TileShape_MNK, ClusterShape_MNK,
+cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong
+>::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+Shape<int,int,int,int>,
+CollectiveMainloop,
+CollectiveEpilogue
+>;
+
+using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+EXPECT_TRUE(passed);
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
