From 1b66da8e791270287d582f3cd999a01c0955e9d7 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 12 Jan 2024 07:21:34 +0000 Subject: [PATCH] Add H100 support for the MoE GeMM kernels --- .../gemm/kernel/default_fpA_intB_traits.h | 25 +++++++ .../gemm/kernel/gemm_moe_problem_visitor.h | 4 +- .../gemm/kernel/moe_cutlass_kernel.h | 6 ++ .../moe_gemm/moe_gemm_kernels_template.h | 72 +++++++++++++++++++ dskernels/ft_gemm/third_party/cutlass | 2 +- 5 files changed, 106 insertions(+), 3 deletions(-) diff --git a/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index a903254..f38c2ed 100644 --- a/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -103,6 +103,31 @@ struct MixedGemmArchTraits< private: using LayoutDetails = LayoutDetailsB; +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; +// ======================= Hopper Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm90, + typename cutlass::platform::enable_if::value + || cutlass::platform::is_same::value>::type> { +private: + using LayoutDetails = LayoutDetailsB; + public: static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; diff --git 
a/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h index 24c4d7e..fcf12f1 100644 --- a/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h +++ b/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -56,7 +56,7 @@ template struct GemmMoeProblemVisitor: - public MoeProblemVisitor, + public MoeProblemVisitor, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, @@ -64,7 +64,7 @@ struct GemmMoeProblemVisitor: static bool const kTransposed = Transposed; - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; using Base = MoeProblemVisitor; using Params = typename Base::Params; diff --git a/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index fa193e4..6ec7080 100644 --- a/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -516,6 +516,12 @@ struct MoeFCGemm { static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage);#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available + static constexpr bool compile_needed = platform::is_same::value; +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available static constexpr bool compile_needed 
= platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); #else diff --git a/dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_template.h b/dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_template.h index 049beed..9748b10 100644 --- a/dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_template.h +++ b/dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_template.h @@ -231,6 +231,57 @@ struct dispatch_stages { } }; +template +struct dispatch_stages 2)>::type> { + static void dispatch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) + { + generic_moe_gemm_kernelLauncher(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + } +}; + + template::dispatch_to_arch(const T* A, stream, occupancy); } + else if (sm_ >=90) { + + dispatch_moe_gemm_to_cutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + //dispatch_moe_gemm_to_cutlass(A, + // B, + // weight_scales, + // biases, + // C, + // total_rows_before_expert, + // total_rows, + // gemm_n, + // gemm_k, + // num_experts, + // gemm_config, + // sm_, + // multi_processor_count_, + // stream, + // occupancy); + } else { throw std::runtime_error("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } diff --git a/dskernels/ft_gemm/third_party/cutlass b/dskernels/ft_gemm/third_party/cutlass index cc85b64..acba5be 160000 --- a/dskernels/ft_gemm/third_party/cutlass +++ b/dskernels/ft_gemm/third_party/cutlass @@ -1 +1 @@ -Subproject commit cc85b64cf676c45f98a17e3a47c0aafcf817f088 +Subproject commit 
acba5beee568792da609ef27275fe9e459a36a25