From 3675ea2d90133f4250dd761525117927af508a77 Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Wed, 11 Sep 2024 21:21:04 +0800
Subject: [PATCH] [Inference] FP8 gemm auto-tune (#9094)

* fp8 cutlass gemm tune

* git ignore third_party

* check csrc/readme.md
---
 .gitignore                                    |    2 +-
 csrc/README.md                                |   13 +-
 csrc/generate_code_gemm_fused_kernels.py      |  537 ++++++++++++++++++
 csrc/gpu/cutlass_kernels/cutlass_helper.h     |   71 +++
 .../fp8_fp8_dual_gemm_scale_bias_act.cu       |    2 +-
 .../fp8_fp8_gemm_scale_bias_act.cu            |  125 ----
 .../fp8_fp8_gemm_scale_bias_act.h             |    3 +
 .../fp8_gemm_fused/fuse_gemm_gelu_template.h  |  456 +++++++++++++++
 .../fp8_gemm_fused/fuse_gemm_noact_template.h |  311 ++++++++++
 .../fp8_gemm_fused/fuse_gemm_relu_template.h  |  311 ++++++++++
 .../fp8_gemm_fused/gemm_scale.h               |  162 ------
 .../fp8_gemm_fused/gemm_scale_bias.h          |  163 ------
 .../fp8_gemm_fused/gemm_scale_bias_gelu.h     |  163 ------
 .../fp8_gemm_fused/gemm_scale_bias_relu.h     |  162 ------
 .../fp8_gemm_fused/gemm_scale_gelu.h          |  162 ------
 csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h   |   18 +-
 .../fp8_fp8_fp8_dual_gemm.cu                  |    5 +-
 .../fp8_fp8_half_gemm.cu                      |   18 +-
 csrc/gpu/helper.h                             |   26 +-
 csrc/gpu/test_fp8_gemm.py                     |   54 ++
 csrc/gpu/test_fp8gemm.py                      |   68 ---
 csrc/setup_cuda.py                            |   22 +-
 csrc/tune_fp8_gemm.sh                         |   29 +
 23 files changed, 1850 insertions(+), 1033 deletions(-)
 create mode 100644 csrc/generate_code_gemm_fused_kernels.py
 delete mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu
 create mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h
 create mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h
 create mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h
 delete mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale.h
 delete mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias.h
 delete mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_gelu.h
 delete mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_relu.h
 delete mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_gelu.h
 create mode 100644 csrc/gpu/test_fp8_gemm.py
 delete mode 100644 csrc/gpu/test_fp8gemm.py
 create mode 100644 csrc/tune_fp8_gemm.sh

diff --git a/.gitignore b/.gitignore
index 3538bea189ad..8ac817c65d68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,6 +126,6 @@ FETCH_HEAD
 ./ppdiffusers/ppdiffusers/version.py
 
 # third party
-csrc/gpu/cutlass_kernels/cutlass
+csrc/third_party/
 dataset/
 output/
diff --git a/csrc/README.md b/csrc/README.md
index 0bd01e28365f..dddef3b433e0 100644
--- a/csrc/README.md
+++ b/csrc/README.md
@@ -10,6 +10,12 @@ pip install -r requirements.txt
 
 ## Build the CUDA operators
 
+Generate the FP8 CUTLASS operators (compilation takes a long time):
+```shell
+python generate_code_gemm_fused_kernels.py
+```
+
+Build:
 ```shell
 python setup_cuda.py install
 ```
@@ -20,9 +26,14 @@ python setup_cuda.py install
 2. Fetch the code:
 git clone -b v3.5.0 --single-branch https://github.com/NVIDIA/cutlass.git
 
-3. Put the downloaded `cutlass` directory under `csrc/gpu/cutlass_kernels/cutlass`
+3. Put the downloaded `cutlass` directory under `csrc/third_party/cutlass`
 
 4. Rebuild the CUDA operators
 ```shell
 python setup_cuda.py install
 ```
+
+### FP8 GEMM auto-tuning
+```shell
+sh tune_fp8_gemm.sh
+```
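The tuning flow ties the pieces of this patch together: running `tune_fp8_gemm.sh` with `FLAGS_use_cutlass_device_best_config_path` unset sweeps the generated kernels and, through the `CutlassGemmConfigManager` destructor introduced below, saves the winners to `fp8_fuse_gemm_config.json`; setting the flag afterwards makes the dispatcher load that file instead of tuning. A minimal sketch of the workflow, assuming the file and flag names in this patch (the inference step itself is a placeholder; `csrc/gpu/test_fp8_gemm.py` is the real driver):

```python
# Sketch of the tune-then-run workflow; file and flag names come from this
# patch, while the inference call itself is a hypothetical placeholder.
import os
import subprocess

# 1) Tune: with the flag unset, the best configs measured during the sweep are
#    written to fp8_fuse_gemm_config.json when the tuning process exits.
env = {k: v for k, v in os.environ.items() if k != "FLAGS_use_cutlass_device_best_config_path"}
subprocess.run(["sh", "tune_fp8_gemm.sh"], check=True, env=env)

# 2) Run: point the runtime at the tuned file; the fused-GEMM dispatcher calls
#    getenv("FLAGS_use_cutlass_device_best_config_path") on every launch.
os.environ["FLAGS_use_cutlass_device_best_config_path"] = "fp8_fuse_gemm_config.json"
# ... launch inference; each fused FP8 GEMM now looks up its "<M, N, K>" entry ...
```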
diff --git a/csrc/generate_code_gemm_fused_kernels.py b/csrc/generate_code_gemm_fused_kernels.py
new file mode 100644
index 000000000000..8834602fa2fe
--- /dev/null
+++ b/csrc/generate_code_gemm_fused_kernels.py
@@ -0,0 +1,537 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+import re
+
+
+def get_candidate_tiles():
+    base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
+
+    base_configs.extend(
+        [
+            ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
+            ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
+            ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
+            ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+            ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+            ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+            ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+            ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
+            ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+            ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+            ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
+        ]
+    )
+
+    return base_configs
+
+
+def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
+    tiles = get_candidate_tiles()
+    candidate_configs = list()
+
+    stages = tuple(i for i in range(min_stages, max_stages + 1, 1))
+    splitks = tuple(i for i in range(min_split_k, max_split_k + 1, 1))
+    hasbias = ("false", "true")
+
+    for act_tag in [
+        ("noact", "LinearCombination"),
+        ("relu", "LinearCombinationRelu"),
+        ("gelu", "LinearCombinationGELU"),
+    ]:
+        candidate_configs.extend([(stages, splitks, tiles, act_tag, hasbias)])
+
+    return candidate_configs
+
+
+# This is a file's header part.
+CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
+
+#pragma once
+
+#include "fp8_gemm_fused/fuse_gemm_{act_tag}_template.h"
+
+"""
+
+
+CommonTail = """
+
+"""
+
+GemmDeclare = """
+template<>
+bool dispatch_fuse_gemm_{act_tag}(GemmEpilogueAllParams);
+
+
+"""
+
+
+GemmSplitKDeclare = """
+template<>
+bool dispatch_fuse_gemm_split_k_{act_tag}(GemmEpilogueAllParams);
+
+
+"""
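Taken together, `get_candidate_tiles` and `get_candidate_configs` define the search space the tuner sweeps: 12 tile shapes crossed with the stage range, the split-k range, and bias on/off, for each of the three epilogues. A quick sketch of counting that space with the functions above — the argument values here are assumptions for illustration; the script's actual defaults come from its `argparse` options:

```python
# Hypothetical illustration of the search-space size, reusing the functions
# defined above; sm/split-k/stage bounds are made-up example values.
configs = get_candidate_configs(sm=89, min_split_k=1, max_split_k=4, min_stages=2, max_stages=5)
for stages, splitks, tiles, (act_tag, _), hasbias in configs:
    # Roughly one generated kernel per (tile, stage) pair, judging by the
    # kernel-ID keys registered in gemm_configs_map below; each is then
    # tunable over the split-k candidates and the bias flag.
    print(act_tag, len(tiles) * len(stages), "kernels,", len(splitks), "split-k candidates")
```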
+
+
+code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
+
+#include <map>
+#include "fp8_fp8_gemm_scale_bias_act.h"
+
+COMMON_DECLARE_string(use_cutlass_device_best_config_path);
+
+std::map<std::string, int> config_map{"""
+
+code_part1 = """
+    {"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}},"""
+
+code_part2 = """
+};
+
+std::map<std::string, int> gemm_configs_map{
+"""
+
+code_part3 = """    {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
+"""
+
+code_part4 = """};
+
+bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, GemmEpilogueAllParams params) {
+  switch (type_id) {"""
+
+code_part5 = """
+    case {type_id}:
+      if (split_k < 2) {
+        params.split_k = 1;
+        switch (kernel_id) {"""
+
+code_part6 = """
+          case {tile_id}:
+            return dispatch_fuse_gemm_{act_tag}(params);
+            break;"""
+
+code_part7 = """
+          default:
+            throw std::runtime_error("cutlass gemm config is invalid.");
+            break;
+        }
+      } else {
+        params.split_k = split_k;
+        switch (kernel_id) {"""
+
+code_part8 = """
+          case {tile_id}:
+            return dispatch_fuse_gemm_split_k_{act_tag}(params);
+            break;"""
+
+code_part9 = """
+          default:
+            throw std::runtime_error("cutlass gemm config is invalid.");
+            break;
+        }
+      }
+      break;"""
+
+code_part10 = """
+    default:
+      throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
+      break;
+  }
+  return false;
+}
+
+
+bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
+  if (config_map.find(params.fuse_gemm_config) == config_map.end()) {
+    throw std::runtime_error("fp8 gemm_fused config is invalid.");
+  }
+
+  int type_id = config_map[params.fuse_gemm_config];
+  int M = (params.M + 31) / 32 * 32;
+  int N = params.N;
+  int K = params.K;
+
+  std::string mkn_string = "<" + std::to_string(M) + ", " + std::to_string(N) + ", " + std::to_string(K) + ">";
+  std::string mkn_split_k_string = mkn_string + ", split_k";
+  int split_k;
+  int kernel_id;
+  std::string best_config;
+  CutlassGemmConfigManager& best_config_manager = CutlassGemmConfigManager::getInstance();
+  if (getenv("FLAGS_use_cutlass_device_best_config_path")) {  // run kernel
+    std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
+    nlohmann::json* config_json = best_config_manager.get_gemm_best_configs(config_file_path);
+    if (config_json->contains(mkn_string)) {
+      best_config = config_json->at(mkn_string);
+    } else {
+      std::cerr << "Cannot find the config for this gemm shape, please tune this shape: " << mkn_string << std::endl;
+    }
+    if (config_json->contains(mkn_split_k_string)) {
+      split_k = config_json->at(mkn_split_k_string);
+    } else {
+      std::cerr << "Cannot find the config (split_k) for this gemm shape, please tune this shape: " << mkn_string << std::endl;
+    }
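The generated dispatcher resolves a shape with two lookups in the best-config JSON: the `"<M, N, K>"` key (with M rounded up to a multiple of 32) maps to a kernel description in the `gemm_configs_map` key format, and the sibling `"<M, N, K>, split_k"` key stores the split-k factor. A hypothetical `fp8_fuse_gemm_config.json` entry inferred from those lookups — the shape and the winning tile here are made-up values:

```python
# Hypothetical tuned entries; the key/value formats mirror mkn_string,
# mkn_split_k_string, and the gemm_configs_map keys in the generated code above.
import json

best_config = {
    "<32, 4096, 4096>": "<32, 128, 64>, <32, 32, 64>, <16, 8, 32>, 4",
    "<32, 4096, 4096>, split_k": 1,
}
print(json.dumps(best_config, indent=4))
```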
diff --git a/csrc/gpu/cutlass_kernels/cutlass_helper.h b/csrc/gpu/cutlass_kernels/cutlass_helper.h
--- a/csrc/gpu/cutlass_kernels/cutlass_helper.h
+++ b/csrc/gpu/cutlass_kernels/cutlass_helper.h
+#include <fstream>
+#include <mutex>
+#include "helper.h"
 #include "cutlass/half.h"
 #include "cutlass/bfloat16.h"
 #include "paddle/extension.h"
@@ -39,4 +42,72 @@ class CutlassDtypeTraits {
  public:
   typedef cutlass::bfloat16_t DataType;
   typedef paddle::bfloat16 data_t;
+};
+
+class CutlassGemmConfigManager {
+ public:
+  static CutlassGemmConfigManager& getInstance() {
+    static CutlassGemmConfigManager instance;
+    return instance;
+  }
+
+  CutlassGemmConfigManager(const CutlassGemmConfigManager&) = delete;
+  CutlassGemmConfigManager& operator=(const CutlassGemmConfigManager&) = delete;
+
+  void update_configs(const nlohmann::json& j) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    for (auto it = j.begin(); it != j.end(); ++it) {
+      json_[it.key()] = it.value();
+    }
+  }
+
+  nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) {
+    if (!load_initialized_) {
+      std::ifstream file(config_file_path);
+      if (!file.good()) {
+        throw std::runtime_error(
+            "cutlass gemm_best_config cannot be found; set its path via "
+            "FLAGS_use_cutlass_device_best_config_path, or unset "
+            "FLAGS_use_cutlass_device_best_config_path to tune gemm_best_config");
+      }
+      json_ = readJsonFromFile(config_file_path);
+      load_initialized_ = true;
+      save_initialized_ = false;
+    }
+    return &json_;
+  }
+
+ private:
+  void save_gemm_best_configs_(const std::string& config_file_path) {
+    std::ifstream file(config_file_path);
+    if (!file.good()) {
+      std::ofstream new_file(config_file_path);
+      new_file << json_.dump(4);
+      new_file.close();
+    } else {
+      nlohmann::json old_json = readJsonFromFile(config_file_path);
+      for (auto it = json_.begin(); it != json_.end(); ++it) {
+        old_json[it.key()] = it.value();
+      }
+      json_ = old_json;
+      std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
+      new_file << json_.dump(4);
+      new_file.close();
+      file.close();
+    }
+    return;
+  }
+
+  CutlassGemmConfigManager() : json_(nullptr), load_initialized_(false), save_initialized_(true) {}
+  ~CutlassGemmConfigManager() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (save_initialized_) {
+      std::string config_file_path = "fp8_fuse_gemm_config.json";
+      save_gemm_best_configs_(config_file_path);
+    }
+    save_initialized_ = true;
+    load_initialized_ = false;
+    json_.clear();
+  }
+
+  mutable std::mutex mutex_;
+  nlohmann::json json_;
+  bool load_initialized_;
+  bool save_initialized_;
+};
\ No newline at end of file
diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu
index c2f38bfc15af..7dbb5fcc73cf 100644
--- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu
+++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu
@@ -30,7 +30,7 @@ std::map<std::string, int> config_map1{
 };
 
 bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
-  switch (config_map1[params.gemm_config]) {
+  switch (config_map1[params.fuse_gemm_config]) {
     case 0:
       dispatch_dual_gemm_scale_swiglu(params);
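Both dispatchers now key kernels off `params.fuse_gemm_config`. For the fused single GEMM, that key is the `"{input_type}_{output_type}_{hasbias}_{act_tag}"` string the generator registers in `config_map` (see `code_part1` above); a sketch of how such a key is composed, with illustrative values:

```python
# Illustrative fuse_gemm_config key, mirroring code_part1's
# "{input_type}_{output_type}_{hasbias}_{act_tag}" template above.
input_type, output_type, hasbias, act_tag = "e4m3", "bf16", "true", "gelu"
print(f"{input_type}_{output_type}_{hasbias}_{act_tag}")  # -> e4m3_bf16_true_gelu
```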
diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu
deleted file mode 100644
index d4f8efc67aed..000000000000
--- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#include <map>
-
-#include "fp8_fp8_gemm_scale_bias_act.h"  // NOLINT
-
-#include "gemm_scale.h"            // NOLINT
-#include "gemm_scale_bias.h"       // NOLINT
-#include "gemm_scale_bias_gelu.h"  // NOLINT
-#include "gemm_scale_bias_relu.h"  // NOLINT
-#include "gemm_scale_gelu.h"       // NOLINT
-
-
-std::map<std::string, int> config_map{
-    {"e4m3_bf16_identity", 0},       {"e4m3_bf16_bias_identity", 1},
-    {"e4m3_bf16_bias_relu", 2},      {"e4m3_bf16_bias_gelu", 3},
-    {"e4m3_bf16_gelu", 4},           {"e4m3_fp16_identity", 5},
-    {"e4m3_fp16_bias_identity", 6},  {"e4m3_fp16_bias_relu", 7},
-    {"e4m3_fp16_bias_gelu", 8},      {"e4m3_fp16_gelu", 9},
-    {"e5m2_bf16_identity", 10},      {"e5m2_bf16_bias_identity", 11},
-    {"e5m2_bf16_bias_relu", 12},     {"e5m2_bf16_bias_gelu", 13},
-    {"e5m2_bf16_gelu", 14},          {"e5m2_fp16_identity", 15},
-    {"e5m2_fp16_bias_identity", 16}, {"e5m2_fp16_bias_relu", 17},
-    {"e5m2_fp16_bias_gelu", 18},     {"e5m2_fp16_gelu", 19},
-};
-
-bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
-  switch (config_map[params.gemm_config]) {
-    case 0:
-      dispatch_gemm_scale<phi::dtype::float8_e4m3fn, phi::dtype::bfloat16>(
-          params);
-      break;
-    case 1:
-      dispatch_gemm_scale_bias<phi::dtype::float8_e4m3fn, phi::dtype::bfloat16>(
-          params);
-      break;
-    case 2:
-      dispatch_gemm_scale_bias_relu<phi::dtype::float8_e4m3fn,
-                                    phi::dtype::bfloat16>(params);
-      break;
-    case 3:
-      dispatch_gemm_scale_bias_gelu<phi::dtype::float8_e4m3fn,
-                                    phi::dtype::bfloat16>(params);
-      break;
-    case 4:
-      dispatch_gemm_scale_gelu<phi::dtype::float8_e4m3fn, phi::dtype::bfloat16>(
-          params);
-      break;
-    case 5:
-      dispatch_gemm_scale<phi::dtype::float8_e4m3fn, phi::dtype::float16>(
-          params);
-      break;
-    case 6:
-      dispatch_gemm_scale_bias<phi::dtype::float8_e4m3fn, phi::dtype::float16>(
-          params);
-      break;
-    case 7:
-      dispatch_gemm_scale_bias_relu<phi::dtype::float8_e4m3fn,
-                                    phi::dtype::float16>(params);
-      break;
-    case 8:
-      dispatch_gemm_scale_bias_gelu<phi::dtype::float8_e4m3fn,
-                                    phi::dtype::float16>(params);
-      break;
-    case 9:
-      dispatch_gemm_scale_gelu<phi::dtype::float8_e4m3fn, phi::dtype::float16>(
-          params);
-      break;
-    case 10:
-      dispatch_gemm_scale<phi::dtype::float8_e5m2, phi::dtype::bfloat16>(
-          params);
-      break;
-    case 11:
-      dispatch_gemm_scale_bias<phi::dtype::float8_e5m2, phi::dtype::bfloat16>(
-          params);
-      break;
-    case 12:
-      dispatch_gemm_scale_bias_relu<phi::dtype::float8_e5m2,
-                                    phi::dtype::bfloat16>(params);
-      break;
-    case 13:
-      dispatch_gemm_scale_bias_gelu<phi::dtype::float8_e5m2,
-                                    phi::dtype::bfloat16>(params);
-      break;
-    case 14:
-      dispatch_gemm_scale_gelu<phi::dtype::float8_e5m2, phi::dtype::bfloat16>(
-          params);
-      break;
-    case 15:
-      dispatch_gemm_scale<phi::dtype::float8_e5m2, phi::dtype::float16>(params);
-      break;
-    case 16:
-      dispatch_gemm_scale_bias<phi::dtype::float8_e5m2, phi::dtype::float16>(
-          params);
-      break;
-    case 17:
-      dispatch_gemm_scale_bias_relu<phi::dtype::float8_e5m2,
-                                    phi::dtype::float16>(params);
-      break;
-    case 18:
-      dispatch_gemm_scale_bias_gelu<phi::dtype::float8_e5m2,
-                                    phi::dtype::float16>(params);
-      break;
-    case 19:
-      dispatch_gemm_scale_gelu<phi::dtype::float8_e5m2, phi::dtype::float16>(
-          params);
-      break;
-    default:
-      throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
-      break;
-  }
-  return false;
-}
-
diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h
index 5ba68c46fd07..8d87da7a9e1d 100644
--- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h
+++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h
@@ -15,6 +15,9 @@
 #pragma once
 
 #include "fp8_common.h"
+#include "fuse_gemm_noact_template.h"
+#include "fuse_gemm_relu_template.h"
+#include "fuse_gemm_gelu_template.h"
 
 bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params);
diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h
new file mode 100644
index 000000000000..a6c21b0a97de
--- /dev/null
+++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h
@@ -0,0 +1,456 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "fp8_fp8_gemm_scale_bias_act.h" // NOLINT + +#include "cutlass/cutlass.h" +#include "cutlass/float8.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_splitk_parallel.h" + +template +bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; // <- MMA Op tile + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // <- ?? + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. 
This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + kAlignmentA, + kAlignmentB, + cutlass::arch::OpMultiplyAddFastAccum>; // NOLINT + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + // cutlass::gemm::GemmUniversalMode mode = + // cutlass::gemm::GemmUniversalMode::kGemm; + + cutlass::gemm::GemmUniversalMode mode = + cutlass::gemm::GemmUniversalMode::kGemm; + // cutlass::gemm::BatchedGemmCoord problem_size = + // cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K, + // params.batch_count}; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale), + ElementCompute(1.0)); + typename Gemm::Arguments arguments{ + mode, + problem_size, + params.batch_count, + epilogue_op, + reinterpret_cast(const_cast(params.A)), + reinterpret_cast(const_cast(params.B)), + reinterpret_cast(const_cast(params.bias)), + reinterpret_cast(params.D), + params.lda * params.M, + params.ldb * params.N, + (int64_t)0, + params.ldd * params.M, + params.lda, + params.ldb, + (int64_t)0, + params.ldd, + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} + + +template +bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = 
WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; // <- MMA Op tile + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // <- ?? + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + kAlignmentA, + kAlignmentB, + cutlass::arch::OpMultiplyAddFastAccum>; // NOLINT + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + // cutlass::gemm::GemmUniversalMode mode = + // cutlass::gemm::GemmUniversalMode::kGemm; + + cutlass::gemm::GemmUniversalMode mode = + cutlass::gemm::GemmUniversalMode::kGemm; + // cutlass::gemm::BatchedGemmCoord problem_size = + // cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K, + // params.batch_count}; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale), + ElementCompute(1.0)); + typename Gemm::Arguments arguments{ + mode, + problem_size, + params.batch_count, + epilogue_op, + reinterpret_cast(const_cast(params.A)), + reinterpret_cast(const_cast(params.B)), + reinterpret_cast(const_cast(params.bias)), + reinterpret_cast(params.D), + params.lda * params.M, + params.ldb * params.N, + (int64_t)0, + params.ldd * params.M, + params.lda, + params.ldb, + (int64_t)0, + params.ldd, + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} + + +template +bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using 
ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, + ElementAccumulator>; + + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; + + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + + using Gemm = cutlass::gemm::device::GemmSplitKParallel; + + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + + ElementComputeEpilogue alpha = ElementComputeEpilogue(params.scale); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 16 partitions + int split_k_slices = params.split_k; + + // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)),params.lda}, + {reinterpret_cast(const_cast(params.B)),params.ldb}, + {reinterpret_cast(const_cast(params.bias)),0}, + {reinterpret_cast(params.D),params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} \ No newline at end of file diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h new file mode 100644 index 000000000000..e42e4e1598db --- /dev/null +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h @@ -0,0 +1,311 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "fp8_fp8_gemm_scale_bias_act.h" // NOLINT + +#include "cutlass/cutlass.h" +#include "cutlass/float8.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_splitk_parallel.h" + +template +bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; // <- MMA Op tile + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // <- ?? + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. 
This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + kAlignmentA, + kAlignmentB, + cutlass::arch::OpMultiplyAddFastAccum>; // NOLINT + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + // cutlass::gemm::GemmUniversalMode mode = + // cutlass::gemm::GemmUniversalMode::kGemm; + + cutlass::gemm::GemmUniversalMode mode = + cutlass::gemm::GemmUniversalMode::kGemm; + // cutlass::gemm::BatchedGemmCoord problem_size = + // cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K, + // params.batch_count}; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale), + ElementCompute(1.0)); + typename Gemm::Arguments arguments{ + mode, + problem_size, + params.batch_count, + epilogue_op, + reinterpret_cast(const_cast(params.A)), + reinterpret_cast(const_cast(params.B)), + reinterpret_cast(const_cast(params.bias)), + reinterpret_cast(params.D), + params.lda * params.M, + params.ldb * params.N, + (int64_t)0, + params.ldd * params.M, + params.lda, + params.ldb, + (int64_t)0, + params.ldd, + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} + + +template +bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = 
WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, + ElementAccumulator>; + + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; + + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + + using Gemm = cutlass::gemm::device::GemmSplitKParallel; + + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + + ElementComputeEpilogue alpha = ElementComputeEpilogue(params.scale); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 16 partitions + int split_k_slices = params.split_k; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)),params.lda}, + {reinterpret_cast(const_cast(params.B)),params.ldb}, + {reinterpret_cast(const_cast(params.bias)),0}, + {reinterpret_cast(params.D),params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} \ No newline at end of file diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h new file mode 100644 index 000000000000..cf1269aa931d --- /dev/null +++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h @@ -0,0 +1,311 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "fp8_fp8_gemm_scale_bias_act.h" // NOLINT + +#include "cutlass/cutlass.h" +#include "cutlass/float8.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_splitk_parallel.h" + +template +bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; // <- MMA Op tile + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // <- ?? + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. 
This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + kAlignmentA, + kAlignmentB, + cutlass::arch::OpMultiplyAddFastAccum>; // NOLINT + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + // cutlass::gemm::GemmUniversalMode mode = + // cutlass::gemm::GemmUniversalMode::kGemm; + + cutlass::gemm::GemmUniversalMode mode = + cutlass::gemm::GemmUniversalMode::kGemm; + // cutlass::gemm::BatchedGemmCoord problem_size = + // cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K, + // params.batch_count}; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale), + ElementCompute(1.0)); + typename Gemm::Arguments arguments{ + mode, + problem_size, + params.batch_count, + epilogue_op, + reinterpret_cast(const_cast(params.A)), + reinterpret_cast(const_cast(params.B)), + reinterpret_cast(const_cast(params.bias)), + reinterpret_cast(params.D), + params.lda * params.M, + params.ldb * params.N, + (int64_t)0, + params.ldd * params.M, + params.lda, + params.ldb, + (int64_t)0, + params.ldd, + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} + + +template +bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { + using ElementInputA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementInputB = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementOutput = + typename std::conditional_t, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementComputeEpilogue = float; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + static int const kAlignmentA = 16; + static int const kAlignmentB = 16; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = SM; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = ThreadBlockShape; + + // This code section describes tile size a warp will compute + using ShapeMMAWarp = 
WarpShape; + + // This code section describes the size of MMA op + using ShapeMMAOp = MMAShape; + + static constexpr auto ScaleType = + hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits:: + value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, + ScaleType>; // <- data type for alpha/beta in linear + // combination function + + // Number of pipelines you want to use + constexpr int NumStages = Stages; + + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, + ElementAccumulator>; + + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; + + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + + using Gemm = cutlass::gemm::device::GemmSplitKParallel; + + + cutlass::gemm::GemmCoord problem_size = + cutlass::gemm::GemmCoord{params.M, params.N, params.K}; + + ElementComputeEpilogue alpha = ElementComputeEpilogue(params.scale); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 16 partitions + int split_k_slices = params.split_k; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)),params.lda}, + {reinterpret_cast(const_cast(params.B)),params.ldb}, + {reinterpret_cast(const_cast(params.bias)),0}, + {reinterpret_cast(params.D),params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator* allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} \ No newline at end of file diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale.h deleted file mode 100644 index 83b91903b64b..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include "fp8_fp8_gemm_scale_bias_act.h" // NOLINT - -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/gemm/device/gemm_universal.h" - -template -bool dispatch_gemm_scale(GemmEpilogueAllParams params) { - using ElementInputA = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputB = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementOutput = - typename std::conditional_t, - cutlass::bfloat16_t, - cutlass::half_t>; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementComputeEpilogue = float; - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - static int const kAlignmentA = 16; - static int const kAlignmentB = 16; - - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<64, 64, 64>; - - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 32, 64>; - - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // <- ?? - - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits:: - value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. 
This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - OnlyAlphaScaling>; // <- data type for alpha/beta in linear - // combination function - - // Number of pipelines you want to use - constexpr int NumStages = 4; - - using Gemm = cutlass::gemm::device::GemmUniversal< - ElementInputA, - LayoutInputA, - ElementInputB, - LayoutInputB, - ElementOutput, - LayoutOutput, - ElementAccumulator, - MMAOp, - SmArch, - ShapeMMAThreadBlock, - ShapeMMAWarp, - ShapeMMAOp, - EpilogueOp, - SwizzleThreadBlock, - NumStages, - kAlignmentA, - kAlignmentB, - cutlass::arch::OpMultiplyAddFastAccum>; // NOLINT - - cutlass::gemm::GemmCoord problem_size = - cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - // cutlass::gemm::GemmUniversalMode mode = - // cutlass::gemm::GemmUniversalMode::kGemm; - - cutlass::gemm::GemmUniversalMode mode = - cutlass::gemm::GemmUniversalMode::kGemm; - // cutlass::gemm::BatchedGemmCoord problem_size = - // cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K, - // params.batch_count}; - - using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; - typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale), - ElementCompute(1.0)); - typename Gemm::Arguments arguments{ - mode, - problem_size, - params.batch_count, - epilogue_op, - reinterpret_cast(const_cast(params.A)), - reinterpret_cast(const_cast(params.B)), - nullptr, - reinterpret_cast(params.D), - params.lda * params.M, - params.ldb * params.N, - (int64_t)0, - params.ldd * params.M, - params.lda, - params.ldb, - (int64_t)0, - params.ldd, - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::can_implement() failed" << std::endl; - return false; - } - - size_t workspace_size = Gemm::get_workspace_size(arguments); - phi::Allocator* allocator = paddle::GetAllocator(params.place); - auto workspace = allocator->Allocate(workspace_size); - - // - // Run the GEMM - // - status = gemm_op(arguments, workspace->ptr(), params.stream); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::run() failed" << std::endl; - return false; - } - return true; -} - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias.h deleted file mode 100644 index d330305dc424..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#pragma once - -#include "fp8_fp8_gemm_scale_bias_act.h" // NOLINT - -#include "cutlass/cutlass.h" -#include "cutlass/float8.h" -#include "cutlass/gemm/device/gemm_universal.h" - - -template -bool dispatch_gemm_scale_bias(GemmEpilogueAllParams params) { - using ElementInputA = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementInputB = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementOutput = - typename std::conditional_t, - cutlass::bfloat16_t, - cutlass::half_t>; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementComputeEpilogue = float; - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - static int const kAlignmentA = 16; - static int const kAlignmentB = 16; - - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm89; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<64, 64, 64>; - - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 32, 64>; - - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; // <- MMA Op tile M = - // 16, N = 8, K = 32 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // <- ?? - - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits:: - value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. 
This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType:: - NoBetaScaling>; // <- data type for alpha/beta in linear combination - // function - - // Number of pipelines you want to use - constexpr int NumStages = 4; - - using Gemm = cutlass::gemm::device::GemmUniversal< - ElementInputA, - LayoutInputA, - ElementInputB, - LayoutInputB, - ElementOutput, - LayoutOutput, - ElementAccumulator, - MMAOp, - SmArch, - ShapeMMAThreadBlock, - ShapeMMAWarp, - ShapeMMAOp, - EpilogueOp, - SwizzleThreadBlock, - NumStages, - kAlignmentA, - kAlignmentB, - cutlass::arch::OpMultiplyAddFastAccum>; // NOLINT - - cutlass::gemm::GemmCoord problem_size = - cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - // cutlass::gemm::GemmUniversalMode mode = - // cutlass::gemm::GemmUniversalMode::kGemm; - - cutlass::gemm::GemmUniversalMode mode = - cutlass::gemm::GemmUniversalMode::kGemm; - // cutlass::gemm::BatchedGemmCoord problem_size = - // cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K, - // params.batch_count}; - - using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; - typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale), - ElementCompute(1.0)); - typename Gemm::Arguments arguments{ - mode, - problem_size, - params.batch_count, - epilogue_op, - reinterpret_cast(const_cast(params.A)), - reinterpret_cast(const_cast(params.B)), - reinterpret_cast(const_cast(params.bias)), - reinterpret_cast(params.D), - params.lda * params.M, - params.ldb * params.N, - (int64_t)0, - params.ldd * params.M, - params.lda, - params.ldb, - (int64_t)0, - params.ldd, - }; - - Gemm gemm_op; - - cutlass::Status status = gemm_op.can_implement(arguments); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::can_implement() failed" << std::endl; - return false; - } - - size_t workspace_size = Gemm::get_workspace_size(arguments); - phi::Allocator* allocator = paddle::GetAllocator(params.place); - auto workspace = allocator->Allocate(workspace_size); - - // - // Run the GEMM - // - status = gemm_op(arguments, workspace->ptr(), params.stream); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Gemm::run() failed" << std::endl; - return false; - } - return true; -} - diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_gelu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_gelu.h deleted file mode 100644 index 9e0d1b37a4e4..000000000000 --- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_gelu.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#pragma once
-
-#include "fp8_fp8_gemm_scale_bias_act.h"  // NOLINT
-
-#include "cutlass/cutlass.h"
-#include "cutlass/float8.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-
-
-template <typename InputType, typename OutputType>
-bool dispatch_gemm_scale_bias_gelu(GemmEpilogueAllParams params) {
-  using ElementInputA = typename std::conditional_t<
-      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
-      cutlass::float_e4m3_t,
-      cutlass::float_e5m2_t>;
-  using ElementInputB = typename std::conditional_t<
-      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
-      cutlass::float_e4m3_t,
-      cutlass::float_e5m2_t>;
-  using ElementOutput = typename std::conditional_t<
-      std::is_same_v<OutputType, phi::dtype::bfloat16>,
-      cutlass::bfloat16_t,
-      cutlass::half_t>;
-
-  using ElementAccumulator = float;
-  using ElementCompute = float;
-  using ElementComputeEpilogue = float;
-
-  using LayoutInputA = cutlass::layout::RowMajor;
-  using LayoutInputB = cutlass::layout::ColumnMajor;
-  using LayoutOutput = cutlass::layout::RowMajor;
-  static int const kAlignmentA = 16;
-  static int const kAlignmentB = 16;
-
-  // This code section describes whether you want to use tensor cores or regular
-  // SIMT cores on GPU SM
-  using MMAOp = cutlass::arch::OpClassTensorOp;
-
-  // This code section describes CUDA SM architecture number
-  using SmArch = cutlass::arch::Sm89;
-
-  // This code section describes the tile size a thread block will compute
-  using ShapeMMAThreadBlock =
-      cutlass::gemm::GemmShape<64, 64, 64>;
-
-  // This code section describes tile size a warp will compute
-  using ShapeMMAWarp =
-      cutlass::gemm::GemmShape<32, 32, 64>;
-
-  // This code section describes the size of MMA op
-  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>;  // <- MMA Op tile M =
-                                                           // 16, N = 8, K = 32
-
-  // This code section describes how threadblocks are scheduled on GPU
-  using SwizzleThreadBlock =
-      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;  // <- ??
-
-  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
-      ElementOutput,  // <- data type of output matrix
-      128 / cutlass::sizeof_bits<ElementOutput>::
-          value,  // <- the number of elements per vectorized
-                  // memory access. For a byte, it's 16
-                  // elements. This becomes the vector width of
-                  // math instructions in the epilogue too
-      ElementAccumulator,  // <- data type of accumulator
-      ElementComputeEpilogue,
-      cutlass::epilogue::thread::ScaleType::
-          NoBetaScaling>;  // <- data type for alpha/beta in linear combination
-                           // function
-
-  // Number of pipelines you want to use
-  constexpr int NumStages = 4;
-
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      ElementInputA,
-      LayoutInputA,
-      ElementInputB,
-      LayoutInputB,
-      ElementOutput,
-      LayoutOutput,
-      ElementAccumulator,
-      MMAOp,
-      SmArch,
-      ShapeMMAThreadBlock,
-      ShapeMMAWarp,
-      ShapeMMAOp,
-      EpilogueOp,
-      SwizzleThreadBlock,
-      NumStages,
-      kAlignmentA,
-      kAlignmentB,
-      cutlass::arch::OpMultiplyAddFastAccum>;  // NOLINT
-
-  cutlass::gemm::GemmCoord problem_size =
-      cutlass::gemm::GemmCoord{params.M, params.N, params.K};
-  // cutlass::gemm::GemmUniversalMode mode =
-  //     cutlass::gemm::GemmUniversalMode::kGemm;
-
-  cutlass::gemm::GemmUniversalMode mode =
-      cutlass::gemm::GemmUniversalMode::kGemm;
-  // cutlass::gemm::BatchedGemmCoord problem_size =
-  //     cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K,
-  //     params.batch_count};
-
-  using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
-  typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale),
-                                                ElementCompute(1.0));
-  typename Gemm::Arguments arguments{
-      mode,
-      problem_size,
-      params.batch_count,
-      epilogue_op,
-      reinterpret_cast<ElementInputA*>(const_cast<void*>(params.A)),
-      reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B)),
-      reinterpret_cast<ElementOutput*>(const_cast<void*>(params.bias)),
-      reinterpret_cast<ElementOutput*>(params.D),
-      params.lda * params.M,
-      params.ldb * params.N,
-      (int64_t)0,
-      params.ldd * params.M,
-      params.lda,
-      params.ldb,
-      (int64_t)0,
-      params.ldd,
-  };
-
-  Gemm gemm_op;
-
-  cutlass::Status status = gemm_op.can_implement(arguments);
-
-  if (status != cutlass::Status::kSuccess) {
-    std::cerr << "Gemm::can_implement() failed" << std::endl;
-    return false;
-  }
-
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
-  phi::Allocator* allocator = paddle::GetAllocator(params.place);
-  auto workspace = allocator->Allocate(workspace_size);
-
-  //
-  // Run the GEMM
-  //
-  status = gemm_op(arguments, workspace->ptr(), params.stream);
-  if (status != cutlass::Status::kSuccess) {
-    std::cerr << "Gemm::run() failed" << std::endl;
-    return false;
-  }
-  return true;
-}
-
diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_relu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_relu.h
deleted file mode 100644
index 0d5b7dfdbf57..000000000000
--- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_bias_relu.h
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#pragma once
-
-#include "fp8_fp8_gemm_scale_bias_act.h"  // NOLINT
-
-#include "cutlass/cutlass.h"
-#include "cutlass/float8.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-
-template <typename InputType, typename OutputType>
-bool dispatch_gemm_scale_bias_relu(GemmEpilogueAllParams params) {
-  using ElementInputA = typename std::conditional_t<
-      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
-      cutlass::float_e4m3_t,
-      cutlass::float_e5m2_t>;
-  using ElementInputB = typename std::conditional_t<
-      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
-      cutlass::float_e4m3_t,
-      cutlass::float_e5m2_t>;
-  using ElementOutput = typename std::conditional_t<
-      std::is_same_v<OutputType, phi::dtype::bfloat16>,
-      cutlass::bfloat16_t,
-      cutlass::half_t>;
-
-  using ElementAccumulator = float;
-  using ElementCompute = float;
-  using ElementComputeEpilogue = float;
-
-  using LayoutInputA = cutlass::layout::RowMajor;
-  using LayoutInputB = cutlass::layout::ColumnMajor;
-  using LayoutOutput = cutlass::layout::RowMajor;
-  static int const kAlignmentA = 16;
-  static int const kAlignmentB = 16;
-
-  // This code section describes whether you want to use tensor cores or regular
-  // SIMT cores on GPU SM
-  using MMAOp = cutlass::arch::OpClassTensorOp;
-
-  // This code section describes CUDA SM architecture number
-  using SmArch = cutlass::arch::Sm89;
-
-  // This code section describes the tile size a thread block will compute
-  using ShapeMMAThreadBlock =
-      cutlass::gemm::GemmShape<64, 64, 64>;
-
-  // This code section describes tile size a warp will compute
-  using ShapeMMAWarp =
-      cutlass::gemm::GemmShape<32, 32, 64>;
-
-  // This code section describes the size of MMA op
-  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>;  // <- MMA Op tile M =
-                                                           // 16, N = 8, K = 32
-
-  // This code section describes how threadblocks are scheduled on GPU
-  using SwizzleThreadBlock =
-      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;  // <- ??
-
-  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
-      ElementOutput,  // <- data type of output matrix
-      128 / cutlass::sizeof_bits<ElementOutput>::
-          value,  // <- the number of elements per vectorized
-                  // memory access. For a byte, it's 16
-                  // elements. This becomes the vector width of
-                  // math instructions in the epilogue too
-      ElementAccumulator,  // <- data type of accumulator
-      ElementComputeEpilogue,
-      cutlass::epilogue::thread::ScaleType::
-          NoBetaScaling>;  // <- data type for alpha/beta in linear combination
-                           // function
-
-  // Number of pipelines you want to use
-  constexpr int NumStages = 4;
-
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      ElementInputA,
-      LayoutInputA,
-      ElementInputB,
-      LayoutInputB,
-      ElementOutput,
-      LayoutOutput,
-      ElementAccumulator,
-      MMAOp,
-      SmArch,
-      ShapeMMAThreadBlock,
-      ShapeMMAWarp,
-      ShapeMMAOp,
-      EpilogueOp,
-      SwizzleThreadBlock,
-      NumStages,
-      kAlignmentA,
-      kAlignmentB,
-      cutlass::arch::OpMultiplyAddFastAccum>;  // NOLINT
-
-  cutlass::gemm::GemmCoord problem_size =
-      cutlass::gemm::GemmCoord{params.M, params.N, params.K};
-  // cutlass::gemm::GemmUniversalMode mode =
-  //     cutlass::gemm::GemmUniversalMode::kGemm;
-
-  cutlass::gemm::GemmUniversalMode mode =
-      cutlass::gemm::GemmUniversalMode::kGemm;
-  // cutlass::gemm::BatchedGemmCoord problem_size =
-  //     cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K,
-  //     params.batch_count};
-
-  using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
-  typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale),
-                                                ElementCompute(1.0));
-  typename Gemm::Arguments arguments{
-      mode,
-      problem_size,
-      params.batch_count,
-      epilogue_op,
-      reinterpret_cast<ElementInputA*>(const_cast<void*>(params.A)),
-      reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B)),
-      reinterpret_cast<ElementOutput*>(const_cast<void*>(params.bias)),
-      reinterpret_cast<ElementOutput*>(params.D),
-      params.lda * params.M,
-      params.ldb * params.N,
-      (int64_t)0,
-      params.ldd * params.M,
-      params.lda,
-      params.ldb,
-      (int64_t)0,
-      params.ldd,
-  };
-
-  Gemm gemm_op;
-
-  cutlass::Status status = gemm_op.can_implement(arguments);
-
-  if (status != cutlass::Status::kSuccess) {
-    std::cerr << "Gemm::can_implement() failed" << std::endl;
-    return false;
-  }
-
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
-  phi::Allocator* allocator = paddle::GetAllocator(params.place);
-  auto workspace = allocator->Allocate(workspace_size);
-
-  //
-  // Run the GEMM
-  //
-  status = gemm_op(arguments, workspace->ptr(), params.stream);
-  if (status != cutlass::Status::kSuccess) {
-    std::cerr << "Gemm::run() failed" << std::endl;
-    return false;
-  }
-  return true;
-}
-
diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_gelu.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_gelu.h
deleted file mode 100644
index 2e0d1497f507..000000000000
--- a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/gemm_scale_gelu.h
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#pragma once
-
-#include "fp8_fp8_gemm_scale_bias_act.h"  // NOLINT
-
-#include "cutlass/cutlass.h"
-#include "cutlass/float8.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-
-template <typename InputType, typename OutputType>
-bool dispatch_gemm_scale_gelu(GemmEpilogueAllParams params) {
-  using ElementInputA = typename std::conditional_t<
-      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
-      cutlass::float_e4m3_t,
-      cutlass::float_e5m2_t>;
-  using ElementInputB = typename std::conditional_t<
-      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
-      cutlass::float_e4m3_t,
-      cutlass::float_e5m2_t>;
-  using ElementOutput = typename std::conditional_t<
-      std::is_same_v<OutputType, phi::dtype::bfloat16>,
-      cutlass::bfloat16_t,
-      cutlass::half_t>;
-
-  using ElementAccumulator = float;
-  using ElementCompute = float;
-  using ElementComputeEpilogue = float;
-
-  using LayoutInputA = cutlass::layout::RowMajor;
-  using LayoutInputB = cutlass::layout::ColumnMajor;
-  using LayoutOutput = cutlass::layout::RowMajor;
-  static int const kAlignmentA = 16;
-  static int const kAlignmentB = 16;
-
-  // This code section describes whether you want to use tensor cores or regular
-  // SIMT cores on GPU SM
-  using MMAOp = cutlass::arch::OpClassTensorOp;
-
-  // This code section describes CUDA SM architecture number
-  using SmArch = cutlass::arch::Sm89;
-
-  // This code section describes the tile size a thread block will compute
-  using ShapeMMAThreadBlock =
-      cutlass::gemm::GemmShape<64, 64, 64>;
-
-  // This code section describes tile size a warp will compute
-  using ShapeMMAWarp =
-      cutlass::gemm::GemmShape<32, 32, 64>;
-
-  // This code section describes the size of MMA op
-  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>;  // <- MMA Op tile M =
-                                                           // 16, N = 8, K = 32
-
-  // This code section describes how threadblocks are scheduled on GPU
-  using SwizzleThreadBlock =
-      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;  // <- ??
-
-  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
-      ElementOutput,  // <- data type of output matrix
-      128 / cutlass::sizeof_bits<ElementOutput>::
-          value,  // <- the number of elements per vectorized
-                  // memory access. For a byte, it's 16
-                  // elements. This becomes the vector width of
-                  // math instructions in the epilogue too
-      ElementAccumulator,  // <- data type of accumulator
-      ElementComputeEpilogue,
-      cutlass::epilogue::thread::ScaleType::
-          OnlyAlphaScaling>;  // <- data type for alpha/beta in linear
-                              // combination function
-
-  // Number of pipelines you want to use
-  constexpr int NumStages = 4;
-
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      ElementInputA,
-      LayoutInputA,
-      ElementInputB,
-      LayoutInputB,
-      ElementOutput,
-      LayoutOutput,
-      ElementAccumulator,
-      MMAOp,
-      SmArch,
-      ShapeMMAThreadBlock,
-      ShapeMMAWarp,
-      ShapeMMAOp,
-      EpilogueOp,
-      SwizzleThreadBlock,
-      NumStages,
-      kAlignmentA,
-      kAlignmentB,
-      cutlass::arch::OpMultiplyAddFastAccum>;  // NOLINT
-
-  cutlass::gemm::GemmCoord problem_size =
-      cutlass::gemm::GemmCoord{params.M, params.N, params.K};
-  // cutlass::gemm::GemmUniversalMode mode =
-  //     cutlass::gemm::GemmUniversalMode::kGemm;
-
-  cutlass::gemm::GemmUniversalMode mode =
-      cutlass::gemm::GemmUniversalMode::kGemm;
-  // cutlass::gemm::BatchedGemmCoord problem_size =
-  //     cutlass::gemm::BatchedGemmCoord{params.M, params.N, params.K,
-  //     params.batch_count};
-
-  using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
-  typename EpilogueOutputOp::Params epilogue_op(ElementCompute(params.scale),
-                                                ElementCompute(1.0));
-  typename Gemm::Arguments arguments{
-      mode,
-      problem_size,
-      params.batch_count,
-      epilogue_op,
-      reinterpret_cast<ElementInputA*>(const_cast<void*>(params.A)),
-      reinterpret_cast<ElementInputB*>(const_cast<void*>(params.B)),
-      nullptr,
-      reinterpret_cast<ElementOutput*>(params.D),
-      params.lda * params.M,
-      params.ldb * params.N,
-      (int64_t)0,
-      params.ldd * params.M,
-      params.lda,
-      params.ldb,
-      (int64_t)0,
-      params.ldd,
-  };
-
-  Gemm gemm_op;
-
-  cutlass::Status status = gemm_op.can_implement(arguments);
-
-  if (status != cutlass::Status::kSuccess) {
-    std::cerr << "Gemm::can_implement() failed" << std::endl;
-    return false;
-  }
-
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
-  phi::Allocator* allocator = paddle::GetAllocator(params.place);
-  auto workspace = allocator->Allocate(workspace_size);
-
-  //
-  // Run the GEMM
-  //
-  status = gemm_op(arguments, workspace->ptr(), params.stream);
-  if (status != cutlass::Status::kSuccess) {
-    std::cerr << "Gemm::run() failed" << std::endl;
-    return false;
-  }
-  return true;
-}
-
diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h b/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h
index 34aa575278d7..59c4862fdf81 100644
--- a/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h
+++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_common.h
@@ -18,11 +18,13 @@
 
 #include "cuda.h"  // NOLINT
 #include "helper.h"
+#include "cutlass_helper.h"
 #include "paddle/extension.h"
 #include "paddle/phi/api/include/context_pool.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/allocator.h"
+#include "paddle/common/flags.h"
 
 
 typedef struct {
@@ -39,11 +41,12 @@ typedef struct {
   int batch_count = 1;
   const phi::GPUPlace &place;
   cudaStream_t stream;
-  int sm_version = 80;
+  int sm_version = 89;
   float leaky_alpha = 1.0;
-  const void *bias;
+  const void *bias = nullptr;
   std::vector<int64_t> &bias_dims;
-  std::string &gemm_config;
+  std::string &fuse_gemm_config;
+  int split_k = 1;
 } GemmEpilogueAllParams;
 
 typedef bool (*func)(GemmEpilogueAllParams);
@@ -65,13 +68,12 @@ typedef struct {
   int batch_count = 1;
   const phi::GPUPlace &place;
   cudaStream_t stream;
-  int sm_version = 80;
-  const void *bias0;
-  const void *bias1;
+  int sm_version = 89;
+  const void *bias0 = nullptr;
+  const void *bias1 = nullptr;
   std::vector<int64_t> &bias_dims0;
   std::vector<int64_t> &bias_dims1;
-  std::string &gemm_config;
+  std::string &fuse_gemm_config;
 } DualGemmEpilogueAllParams;
 
 typedef bool (*func1)(DualGemmEpilogueAllParams);
-
diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu
index e84198514a8a..39d4869b5d00 100644
--- a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu
+++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu
@@ -114,7 +114,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_fp8_dual_gemm(
   }
   std::string act = (activation_type == "") ? "swiglu" : activation_type;
 
-  std::string gemm_config =
+  std::string fuse_gemm_config =
       input_dtype + "_" + output_dtype + "_" + isbias + bias_dtype + act;
 
   DualGemmEpilogueAllParams params = {
@@ -139,7 +139,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_fp8_dual_gemm(
       bias_data1,
       bias_dims0,
       bias_dims1,
-      gemm_config};
+      fuse_gemm_config};
   fp8_fp8_dual_gemm_scale_bias_act(params);
 
   return {out};
@@ -232,4 +232,3 @@ PD_BUILD_OP(cutlass_fp8_fp8_fp8_dual_gemm_fused)
     .SetKernelFn(PD_KERNEL(cutlass_fp8_fp8_fp8_dual_gemm))
     .SetInferShapeFn(PD_INFER_SHAPE(CutlassFp8Fp8Fp8DualGemmFusedInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(CutlassFp8Fp8Fp8DualGemmFusedInferDtype));
-
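The renamed `fuse_gemm_config` string is not just a label: it is the key under which the dispatcher selects a kernel for a given dtype/bias/activation combination, so the rename keeps the key format in one consistent shape across both GEMM paths. A minimal sketch of the key construction, mirroring the concatenation order in the `.cu` files in this patch (the Python helper itself is hypothetical and only illustrative):

```python
def make_fuse_gemm_config(input_dtype: str, output_dtype: str, has_bias: bool, act: str) -> str:
    """Hypothetical sketch of the key format built in fp8_fp8_half_gemm.cu:
    "<input_dtype>_<output_dtype>_<isbias>_<act>"."""
    isbias = "true" if has_bias else "false"
    act = "noact" if act in ("", "identity") else act
    return "_".join([input_dtype, output_dtype, isbias, act])


# e.g. the default bfloat16 no-bias identity GEMM:
assert (
    make_fuse_gemm_config("float8_e4m3fn", "bfloat16", False, "identity")
    == "float8_e4m3fn_bfloat16_false_noact"
)
```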
"noact" : activation_type; - std::string gemm_config = - input_dtype + "_" + cutlass_output_dtype + "_" + isbias + act; + std::string fuse_gemm_config = + input_dtype + "_" + output_dtype + "_" + isbias + "_" + act; void* bias_data = nullptr; std::vector bias_dims{}; @@ -137,7 +134,7 @@ std::vector cutlass_fp8_fp8_half_gemm( 0.01, // for leaky_relu bias_data, bias_dims, - gemm_config}; + fuse_gemm_config}; fp8_fp8_gemm_scale_bias_act(params); return {out}; } @@ -206,4 +203,3 @@ PD_BUILD_OP(cutlass_fp8_fp8_half_gemm_fused) .SetKernelFn(PD_KERNEL(cutlass_fp8_fp8_half_gemm)) .SetInferShapeFn(PD_INFER_SHAPE(CutlassFp8Fp8HalfGemmFusedInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(CutlassFp8Fp8HalfGemmFusedInferDtype)); - diff --git a/csrc/gpu/helper.h b/csrc/gpu/helper.h index d7884e279757..ceccbd4ee4a0 100644 --- a/csrc/gpu/helper.h +++ b/csrc/gpu/helper.h @@ -15,6 +15,14 @@ #pragma once #include "paddle/extension.h" +#include +#include +#include +#include +#include +#include +#include +#include #ifdef PADDLE_WITH_HIP #include #include @@ -27,6 +35,11 @@ namespace cub = hipcub; #include #include #endif +#include +#include +#include "nlohmann/json.hpp" + +using json = nlohmann::json; constexpr int kBlockSize = 256; constexpr int kNumWaves = 16; @@ -146,4 +159,15 @@ HOSTDEVICE inline void Store(const AlignedVector& vec, T* addr) { *addr_vec = vec; } -constexpr int VEC_16B = 16; \ No newline at end of file +constexpr int VEC_16B = 16; + +inline json readJsonFromFile(const std::string& filePath) { + std::ifstream file(filePath); + if (!file.is_open()) { + throw std::runtime_error("Unable to open file: " + filePath); + } + + json j; + file >> j; + return j; +} \ No newline at end of file diff --git a/csrc/gpu/test_fp8_gemm.py b/csrc/gpu/test_fp8_gemm.py new file mode 100644 index 000000000000..c89d6afb1338 --- /dev/null +++ b/csrc/gpu/test_fp8_gemm.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/csrc/gpu/test_fp8_gemm.py b/csrc/gpu/test_fp8_gemm.py
new file mode 100644
index 000000000000..c89d6afb1338
--- /dev/null
+++ b/csrc/gpu/test_fp8_gemm.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import paddle
+from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused
+
+
+def setup_args():
+    """Parse the GEMM shape sweep arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--m_min", type=int, help="range of gemm shape: m_min")
+    parser.add_argument("--m_max", type=int, help="range of gemm shape: m_max")
+    parser.add_argument("--n", nargs="+", type=int, help="List of gemm shape: n")
+    parser.add_argument("--k", nargs="+", type=int, help="List of gemm shape: k")
+    args = parser.parse_args()
+    return args
+
+
+def gemm(m, n, k):
+    A = paddle.ones([m, k], dtype="float8_e4m3fn")
+    B = paddle.ones([n, k], dtype="float8_e4m3fn")
+    res = cutlass_fp8_fp8_half_gemm_fused(
+        A, B, bias=None, transpose_x=False, transpose_y=True, output_dtype="bfloat16", scale=0.5, act="identity"
+    )
+    print(f"m: {m}, n: {n}, k: {k}")
+    print(res)
+
+
+if __name__ == "__main__":
+    args = setup_args()
+
+    m_min = args.m_min
+    m_max = args.m_max
+    ns = args.n
+    ks = args.k
+
+    for m in range(m_min, m_max, 32):
+        for n in ns:
+            for k in ks:
+                gemm(m, n, k)
+                paddle.device.cuda.empty_cache()
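Because the sweep multiplies all-ones matrices, each result is easy to verify by hand: with `transpose_y=True` and `scale=0.5`, every element of `A @ B.T * scale` equals `0.5 * k`. A single-shape sanity check along those lines, assuming `paddlenlp_ops` has been built on an SM 8.9+ GPU:

```python
import paddle
from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused

# One shape from the llama3-8B sweep; any (m, n, k) works the same way.
m, n, k = 32, 4096, 4096
A = paddle.ones([m, k], dtype="float8_e4m3fn")
B = paddle.ones([n, k], dtype="float8_e4m3fn")
out = cutlass_fp8_fp8_half_gemm_fused(
    A, B, bias=None, transpose_x=False, transpose_y=True,
    output_dtype="bfloat16", scale=0.5, act="identity",
)

# Every element should be 0.5 * k (2048.0, exactly representable in bfloat16).
expected = paddle.full([m, n], 0.5 * k, dtype="float32")
print(paddle.allclose(paddle.cast(out, "float32"), expected))
```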
diff --git a/csrc/gpu/test_fp8gemm.py b/csrc/gpu/test_fp8gemm.py
deleted file mode 100644
index b41ba2a93b82..000000000000
--- a/csrc/gpu/test_fp8gemm.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-import paddle
-
-if os.getenv("FLAGS_CUTLASS_FP8_GEMM", "True") == "True":
-    from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused as fp8_gemm_fused
-else:
-    from paddle.linalg import fp8_fp8_half_gemm_fused as fp8_gemm_fused
-
-A = paddle.ones([2, 32, 64], dtype="float8_e4m3fn")
-B = paddle.ones([2, 32, 64], dtype="float8_e4m3fn")
-
-res0 = fp8_gemm_fused(
-    A,
-    B,
-    bias=None,
-    transpose_x=False,
-    transpose_y=True,
-    output_dtype="float16",
-    scale=0.5,
-    act="identity",
-)
-print("res0: ", res0)
-
-A = paddle.ones([2, 32, 64], dtype="float8_e4m3fn")
-B = paddle.ones([2, 128, 64], dtype="float8_e4m3fn")
-
-res1 = fp8_gemm_fused(
-    A,
-    B,
-    bias=None,
-    transpose_x=False,
-    transpose_y=True,
-    output_dtype="bfloat16",
-    scale=0.5,
-    act="identity",
-)
-
-A = paddle.ones([2, 32, 64], dtype="float32")
-B = paddle.ones([2, 128, 64], dtype="float32")
-expect_result = 0.5 * paddle.matmul(A, B.transpose([0, 2, 1]))
-
-result0 = paddle.equal_all(
-    paddle.cast(res0, "float32"),
-    paddle.to_tensor(expect_result),
-)
-
-result1 = paddle.equal_all(
-    paddle.cast(res1, "float32"),
-    paddle.to_tensor(expect_result),
-)
-
-print("result0: ", result0)
-print("result1: ", result1)
diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py
index 641255197b0b..a06eb6b9e760 100644
--- a/csrc/setup_cuda.py
+++ b/csrc/setup_cuda.py
@@ -32,6 +32,15 @@ def clone_git_repo(version, repo_url, destination_path):
         return False
 
 
+def find_end_files(directory, end_str):
+    gen_files = []
+    for root, dirs, files in os.walk(directory):
+        for file in files:
+            if file.endswith(end_str):
+                gen_files.append(os.path.join(root, file))
+    return gen_files
+
+
 def get_sm_version():
     prop = paddle.device.cuda.get_device_properties()
     cc = prop.major * 10 + prop.minor
@@ -50,6 +59,7 @@ def strtobool(v):
         f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
     )
 
+
 def get_gencode_flags():
     if not strtobool(os.getenv("FLAG_LLM_PDC", "False")):
         prop = paddle.device.cuda.get_device_properties()
@@ -99,7 +109,7 @@ def get_gencode_flags():
     "./gpu/tune_cublaslt_gemm.cu",
 ]
 
-cutlass_dir = "gpu/cutlass_kernels/cutlass"
+cutlass_dir = "third_party/cutlass"
 nvcc_compile_args = gencode_flags
 
 if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
@@ -107,6 +117,12 @@ def get_gencode_flags():
         os.makedirs(cutlass_dir)
     clone_git_repo("v3.5.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
 
+json_dir = "third_party/nlohmann_json"
+if not os.path.exists(json_dir) or not os.listdir(json_dir):
+    if not os.path.exists(json_dir):
+        os.makedirs(json_dir)
+    clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
+
 nvcc_compile_args += [
     "-O3",
     "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
    "-Igpu/cutlass_kernels",
-    "-Igpu/cutlass_kernels/cutlass/include",
+    "-Ithird_party/cutlass/include",
+    "-Ithird_party/nlohmann_json/single_include",
     "-Igpu/fp8_gemm_with_cutlass",
     "-Igpu",
 ]
@@ -125,6 +142,7 @@ def get_gencode_flags():
     sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
 
 if cc >= 89:
+    sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")
     sources += [
         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
         "gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu",
diff --git a/csrc/tune_fp8_gemm.sh b/csrc/tune_fp8_gemm.sh
new file mode 100644
index 000000000000..7089c5adb6cb
--- /dev/null
+++ b/csrc/tune_fp8_gemm.sh
@@ -0,0 +1,29 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# llama2-7B
+# nohup python ./gpu/test_fp8_gemm.py \
+#     --m_min 32 \
+#     --m_max 2049 \
+#     --n 4096 12288 \
+#     --k 4096 11008 \
+#     > tune_gemm.log 2>&1 &
+
+# llama3-8B
+nohup python ./gpu/test_fp8_gemm.py \
+    --m_min 32 \
+    --m_max 32768 \
+    --n 4096 6144 \
+    --k 4096 14336 \
+    > tune_gemm.log 2>&1 &
\ No newline at end of file
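The `--n`/`--k` lists enumerate the weight shapes of the target model's linear layers while `m` sweeps the runtime batch/sequence dimension. As a sketch of where the llama3-8B numbers above can come from, using the public model dimensions (hidden size 4096, FFN intermediate size 14336, 32 query heads and 8 KV heads with head dim 128):

```python
# Derive the tuning sweep lists from llama3-8B's linear-layer shapes.
hidden = 4096
intermediate = 14336
num_heads, num_kv_heads, head_dim = 32, 8, 128

# Fused QKV projection output: q (32 heads) + k and v (8 heads each).
qkv_out = (num_heads + 2 * num_kv_heads) * head_dim  # 6144

ns = sorted({hidden, qkv_out})        # [4096, 6144]
ks = sorted({hidden, intermediate})   # [4096, 14336]
print("--n", *ns, "--k", *ks)  # matches the llama3-8B invocation above
```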