Skip to content

Commit

Permalink
Support ElementD to be void for tma (NVIDIA#1153)
Browse files Browse the repository at this point in the history
* Support void D with AuxStore

* refine get_element_aux
  • Loading branch information
kongroo authored Jan 16, 2024
1 parent 751eb9a commit 362abbf
Show file tree
Hide file tree
Showing 6 changed files with 807 additions and 47 deletions.
18 changes: 12 additions & 6 deletions include/cutlass/epilogue/collective/builders/sm90_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,21 @@ template <
class ElementC_,
class GmemLayoutTagC_,
int AlignmentC,
class ElementD,
class ElementD_,
class GmemLayoutTagD,
int AlignmentD,
class FusionOpOrCallbacks,
class DispatchPolicy
>
struct Sm90TmaBuilderImpl {
// Passing void D disables destination store + smem allocation
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
fusion::get_element_aux_t<FusionOpOrCallbacks>, ElementD_>;

// Passing void C disables source load + smem allocation
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;

using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;

Expand Down Expand Up @@ -292,7 +296,7 @@ struct Sm90TmaBuilderImpl {
EpilogueTile_MN,
ElementC_, // Need to pass void through to expose via GemmUniversal
GmemStrideTypeC,
ElementD,
ElementD_,
GmemStrideTypeD,
FusionCallbacks,
CopyOpG2S,
Expand Down Expand Up @@ -474,7 +478,7 @@ template <
class ElementC,
class GmemLayoutTagC,
int AlignmentC,
class ElementD,
class ElementD_,
class GmemLayoutTagD,
int AlignmentD,
class Schedule,
Expand All @@ -491,14 +495,16 @@ struct CollectiveBuilder<
ElementC,
GmemLayoutTagC,
AlignmentC,
ElementD,
ElementD_,
GmemLayoutTagD,
AlignmentD,
Schedule,
FusionOperation,
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> >> {
private:
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
fusion::get_element_aux_t<FusionOperation>, ElementD_>;
using EpilogueTile_MN =
decltype(detail::sm90_compute_tile_shape_or_override<ElementD, EpilogueTileType, Schedule, TileShape_MNK>());
using DispatchPolicy =
Expand All @@ -514,7 +520,7 @@ public:
ElementC,
GmemLayoutTagC,
AlignmentC,
ElementD,
ElementD_,
GmemLayoutTagD,
AlignmentD,
FusionOperation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,14 @@ class CollectiveEpilogue<
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

private:
using SmemElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD>;
using SmemElementD = cute::conditional_t<not is_destination_supported,fusion::get_element_aux_t<FusionCallbacks>, ElementD>;
static_assert(not cute::is_void_v<SmemElementD>, "SmemElementD is void");
using SmemElementC = cute::conditional_t<not is_source_supported,SmemElementD,ElementC>; // prevents void ref breakages
constexpr static int StagesC = StagesC_;
constexpr static int StagesD = StagesD_;
constexpr static bool ReuseSmemC = ReuseSmemC_;
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported;

constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();
constexpr static bool is_m_major_D = detail::is_m_major<StrideD>();
Expand All @@ -139,23 +142,33 @@ class CollectiveEpilogue<
make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int<ReuseSmemC ? StagesC : StagesD>{}),
cute::conditional_t<is_m_major_D, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));

constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC
constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC
&& cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{}));
static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met");

constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});

using EmptyType = cute::tuple<>;
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
array_aligned<SmemElementC, size(SmemLayoutC{}), SmemAlignmentC>,
EmptyType>;
using SmemDStorage = cute::conditional_t<is_destination_supported,
array_aligned<SmemElementD, size(SmemLayoutD{}), SmemAlignmentD>,
EmptyType>;

struct TensorStorageImpl: cute::tuple<SmemCStorage, SmemDStorage> {
using Base = cute::tuple<SmemCStorage, SmemDStorage>;

constexpr decltype(auto)
smem_C() {
return cute::get<0>(static_cast<Base &>(*this));
}

struct TensorStorageWithC {
alignas(SmemAlignmentC) array_aligned<SmemElementC, size(SmemLayoutC{})> smem_C;
alignas(SmemAlignmentD) array_aligned<ElementD, size(SmemLayoutD{})> smem_D;

using FusionStorage = typename FusionCallbacks::SharedStorage;
FusionStorage thread;
};

struct TensorStorageWithoutC {
alignas(SmemAlignmentD) array_aligned<ElementD, size(SmemLayoutD{})> smem_D;
constexpr decltype(auto)
smem_D() {
return cute::get<1>(static_cast<Base &>(*this));
}

using FusionStorage = typename FusionCallbacks::SharedStorage;
FusionStorage thread;
Expand All @@ -175,8 +188,7 @@ class CollectiveEpilogue<
using StorePipelineState = cutlass::PipelineState<ReuseSmemC ? StagesC : StagesD>;

struct SharedStorage {
using TensorStorage =
cute::conditional_t<not is_source_supported or ReuseSmemC, TensorStorageWithoutC, TensorStorageWithC>;
using TensorStorage = TensorStorageImpl;
TensorStorage tensors;

using PipelineStorage = typename LoadPipeline::SharedStorage;
Expand All @@ -203,7 +215,7 @@ class CollectiveEpilogue<
SmemLayoutC{}(_,_,0)));
using TMA_D = decltype(make_tma_copy(
CopyOpS2G{},
make_tensor(make_gmem_ptr(static_cast<ElementD const*>(nullptr)),
make_tensor(make_gmem_ptr(static_cast<SmemElementD const*>(nullptr)),
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
SmemLayoutD{}(_,_,0)));

Expand Down Expand Up @@ -233,16 +245,16 @@ class CollectiveEpilogue<
;

typename Params::TMA_C tma_load_c;
if constexpr (not cute::is_void_v<ElementC>) {
if constexpr (is_source_supported) {
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
}

Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
typename Params::TMA_D tma_store_d = make_tma_copy(
CopyOpS2G{},
tensor_d,
SmemLayoutD{}(_,_,0));
typename Params::TMA_D tma_store_d;
if constexpr (is_destination_supported) {
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0));
}

return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
Expand Down Expand Up @@ -272,8 +284,11 @@ class CollectiveEpilogue<
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;

constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits<ElementD>::value;
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), StrideD{});
bool implementable = true;
if constexpr (is_destination_supported) {
constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits<ElementD>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), StrideD{});
}

if constexpr (not cute::is_void_v<ElementC>) {
constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits<ElementC>::value;
Expand Down Expand Up @@ -309,8 +324,12 @@ class CollectiveEpilogue<
CUTLASS_DEVICE
static void
prefetch_tma_descriptors(Params const& epilogue_params) {
cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor());
cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor());
if constexpr (is_source_supported) {
cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor());
}
if constexpr (is_destination_supported) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor());
}
}

CUTLASS_HOST_DEVICE
Expand Down Expand Up @@ -365,9 +384,14 @@ class CollectiveEpilogue<
Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)

// Apply epilogue subtile, get matching smem tensor
SmemElementC* ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D.data());
if constexpr (not ReuseSmemC and is_source_supported) {
ptr_sC = shared_tensors.smem_C.data();
SmemElementC* ptr_sC = nullptr;

if constexpr (is_source_supported) {
if constexpr (ReuseSmemC) {
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
} else {
ptr_sC = shared_tensors.smem_C().data();
}
}
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
Expand Down Expand Up @@ -499,11 +523,20 @@ class CollectiveEpilogue<
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)

// Construct the corresponding pipelined smem tensors
SmemElementC* ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D.data());
if constexpr (not ReuseSmemC and is_source_supported) {
ptr_sC = shared_tensors.smem_C.data();
SmemElementC* ptr_sC = nullptr;
if constexpr (is_source_supported) {
if constexpr (ReuseSmemC) {
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
} else {
ptr_sC = shared_tensors.smem_C().data();
}
}
ElementD* ptr_sD = shared_tensors.smem_D.data();

SmemElementD* ptr_sD = nullptr;
if constexpr (is_destination_supported) {
ptr_sD = shared_tensors.smem_D().data();
}

Tensor sC_epi = cute::as_position_independent_swizzle_tensor(
make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
Tensor sD_epi = cute::as_position_independent_swizzle_tensor(
Expand All @@ -514,19 +547,19 @@ class CollectiveEpilogue<
TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma);

// (t)hread-partition for (r)egister to (s)mem copy (tRS_)
TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom<CopyOpR2S,ElementD>{}, tiled_copy_C_atom);
TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom<CopyOpR2S,SmemElementD>{}, tiled_copy_C_atom);
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D)

// Allocate D registers
Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi))));
Tensor tRS_rD = make_tensor<ElementD>(tRS_rD_layout); // (R2S,R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<SmemElementD>(tRS_rD_layout); // (R2S,R2S_M,R2S_N)

// Vectorized fragment view
constexpr int FragmentSize = DispatchPolicy::FragmentSize;
Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
Tensor tRS_rD_frg = recast<Array<ElementD , FragmentSize>>(tRS_rD);
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");

// (t)hread-partition for (s)mem to (r)egister copy (tSR_)
Expand Down Expand Up @@ -653,7 +686,9 @@ class CollectiveEpilogue<
}

// Copy tile from register to smem
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
if constexpr (is_destination_supported) {
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
}

// Post visit, pre async fence callback entry point
constexpr bool issue_smem_store = true; // No smem store predication
Expand All @@ -662,8 +697,10 @@ class CollectiveEpilogue<
// Write the tile from smem to gmem with TMA
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
synchronize(); // ensure all threads have issued their async fence
if (issue_tma_store) {
copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n));
if constexpr (is_destination_supported) {
if (issue_tma_store) {
copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n));
}
}

// Post async fence, pre TMA commit callback entry point
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,33 @@ struct FusionCallbacks<
};

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <class FusionOpOrCallbacks, class = cute::void_t<>>
struct get_element_aux {
using type = void;
};

template <class FusionOpOrCallbacks>
struct get_element_aux<FusionOpOrCallbacks, cute::void_t<typename FusionOpOrCallbacks::ElementAux>> {
using type = typename FusionOpOrCallbacks::ElementAux;
};

template <class NodeOp, class... ChildOps>
struct get_element_aux<Sm90TreeVisitor<NodeOp, ChildOps...>, cute::void_t<>> {
using type = typename get_element_aux<NodeOp>::type;
};

template <class... Ts>
struct get_element_aux<FusionCallbacks<Ts...>, cute::void_t<typename FusionCallbacks<Ts...>::Operation>> {
private:
using Operation = typename FusionCallbacks<Ts...>::Operation;
public:
using type = typename get_element_aux<Operation>::type;
};
}

template <class Callbacks>
using get_element_aux_t = typename detail::get_element_aux<Callbacks>::type;

} // namespace cutlass::epilogue::fusion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ template <
bool EnableNullptr = true // Noop on nullptr params
>
struct Sm90AuxStore {
using ElementAux = Element;
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");

constexpr static bool is_m_major = epilogue::collective::detail::is_m_major<StrideMNL>();
Expand Down
1 change: 1 addition & 0 deletions test/unit/gemm/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ cutlass_test_unit_add_executable(
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
Expand Down
Loading

0 comments on commit 362abbf

Please sign in to comment.