Skip to content

Commit

Permalink
Add H100 support for the MoE GeMM kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaYazdaniAminabadi committed Jan 12, 2024
1 parent bdb30fd commit 1b66da8
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,31 @@ struct MixedGemmArchTraits<
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;

public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;

static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using Operator = typename LayoutDetails::Operator;
};
// ======================= Hopper Traits ==============================
template<typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm90,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm90>;

public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ template<typename ThreadblockShape,
int ThreadCount,
bool Transposed = false>
struct GemmMoeProblemVisitor:
public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<Transposed>,
public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {

static bool const kTransposed = Transposed;

using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<Transposed>;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base =
MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,12 @@ struct MoeFCGemm {
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,57 @@ struct dispatch_stages {
}
};

// Partial specialization of dispatch_stages for Hopper (Sm90) with a
// multistage mainloop (enabled only when Stages > 2). It is a thin
// forwarding shim: every runtime argument is passed straight through to
// generic_moe_gemm_kernelLauncher instantiated with the Sm90 arch tag,
// which does all of the actual kernel configuration and launch work.
template<typename T,
         typename WeightType,
         typename EpilogueTag,
         typename ThreadblockShape,
         typename WarpShape,
         int Stages>
struct dispatch_stages<T,
                       WeightType,
                       cutlass::arch::Sm90,
                       EpilogueTag,
                       ThreadblockShape,
                       WarpShape,
                       Stages,
                       typename std::enable_if<(Stages > 2)>::type> {

    // Launch the MoE grouped GEMM for this (arch, shape, stage-count)
    // combination.
    //   A, B                      : activation and (possibly quantized) weight buffers
    //   weight_scales             : per-channel dequantization scales (may be unused for fp weights)
    //   biases                    : optional bias tensor added in the epilogue
    //   C                         : output buffer
    //   total_rows_before_expert  : prefix sums delimiting each expert's rows
    //   gemm_n, gemm_k            : problem dimensions shared by all experts
    //   num_experts               : number of grouped GEMM problems
    //   gemm_config               : tile/stage configuration selected by the caller
    //   multi_processor_count     : SM count used for occupancy-based sizing
    //   stream                    : CUDA stream the kernel is enqueued on
    //   occupancy                 : optional out-param for achieved occupancy
    static void dispatch(const T* A,
                         const WeightType* B,
                         const T* weight_scales,
                         const T* biases,
                         T* C,
                         int64_t* total_rows_before_expert,
                         int64_t gemm_n,
                         int64_t gemm_k,
                         int num_experts,
                         CutlassGemmConfig gemm_config,
                         int multi_processor_count,
                         cudaStream_t stream,
                         int* occupancy = nullptr)
    {
        // Forward verbatim; the launcher owns all Sm90-specific behavior.
        generic_moe_gemm_kernelLauncher<T, WeightType, cutlass::arch::Sm90, EpilogueTag,
                                        ThreadblockShape, WarpShape, Stages>(
            A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
            num_experts, gemm_config, multi_processor_count, stream, occupancy);
    }
};


template<typename T,
typename WeightType,
typename arch,
Expand Down Expand Up @@ -726,6 +777,27 @@ void MoeGemmRunner<T, V>::dispatch_to_arch<EpilogueTag>(const T* A,
stream,
occupancy);
}
else if (sm_ >=90) {

dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
stream, occupancy);
//dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm90, EpilogueTag>(A,
// B,
// weight_scales,
// biases,
// C,
// total_rows_before_expert,
// total_rows,
// gemm_n,
// gemm_k,
// num_experts,
// gemm_config,
// sm_,
// multi_processor_count_,
// stream,
// occupancy);
}
else {
throw std::runtime_error("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
}
Expand Down
2 changes: 1 addition & 1 deletion dskernels/ft_gemm/third_party/cutlass
Submodule cutlass updated 2073 files

0 comments on commit 1b66da8

Please sign in to comment.