From 541f9f2d8aef9697fd7ccb6a7c0644da273f33b6 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Thu, 28 Oct 2021 19:08:41 -0700 Subject: [PATCH] [BYOC] CUTLASS integration (#9261) * byoc cutlass * add cmake and fix build * test worked but accuracy is bad * fixed argument printing properly * moving files * moving contents of cutlass_profiler into python/tvm/contrib/cutlass * run black * remove irrelavant codegen code * clang format * tried replacing sm 75 with 80, didn't help improve accuracy * remove irrelavant code from generator * tried dense + bias fusion but generated cu file does not compile * dense + bias worked after adding Leyuan's patch, bias + relu worked too * tried adding sm80 generator but accuracy is still off * remove GemmUniversal generator * cleanup partition and build * moved partition, profile and build function out of test * turned out the result match's TVM non-cutlass result. Numpy fp16 matmul is busted? * clean up test * LinearCombination can be reused for bias only epilogue * remove unsupported epilogues like gelu * removing deadcode * unify gemm templates for with or without beta scaling * supported gelu but accuracy is slightly off * gelu test passed with relaxed rtol * cleanup * remove unused stuff from library.py * move profiler template into its own file * removed gemm_profiler.py * move contents of compile_engine.py into gen_gemm.py * rename to profiler_template.cu to avoid CI issue * cleaning up trying to pass pylint * add missing asf header * run black * fixing many pylint issues except wildcard import * fixed wildcard warning * add missing CUTLASS.cmake file, restore gemm_profiler.py * pylint * minor fix * add license * start filling in TODO doc * rename GemmProfiler to GemmProfilerEmitter * more renaming and doc * add doc to the main compile API * refactored generator * run black * black fix * finish doc TODO * add test for 32 bit accum * fixed kernel generator to correctly handle fp32 accum * revise build-related API * add option to profile only one kernel * add option to enable parallel compilation * clean up gen_gemm * doc update * profile_cutlass_kernels -> tune_cutlass_kernels Co-authored-by: leyuan.wang Co-authored-by: Masahiro Masuda --- .gitmodules | 3 + 3rdparty/cutlass | 1 + CMakeLists.txt | 2 + LICENSE | 5 + cmake/modules/contrib/CUTLASS.cmake | 23 ++ licenses/LICENSE.cutlass.txt | 23 ++ python/tvm/contrib/cutlass/__init__.py | 18 + python/tvm/contrib/cutlass/build.py | 172 ++++++++ python/tvm/contrib/cutlass/gemm_operation.py | 262 +++++++++++++ python/tvm/contrib/cutlass/gemm_profiler.py | 196 ++++++++++ python/tvm/contrib/cutlass/gen_gemm.py | 355 +++++++++++++++++ python/tvm/contrib/cutlass/library.py | 219 +++++++++++ python/tvm/relay/op/contrib/__init__.py | 1 + python/tvm/relay/op/contrib/cutlass.py | 74 ++++ .../backend/contrib/codegen_c/codegen_c.h | 6 + src/relay/backend/contrib/cutlass/codegen.cc | 369 ++++++++++++++++++ src/relay/backend/contrib/dnnl/codegen.cc | 6 - tests/python/contrib/test_cutlass.py | 135 +++++++ 18 files changed, 1864 insertions(+), 6 deletions(-) create mode 160000 3rdparty/cutlass create mode 100644 cmake/modules/contrib/CUTLASS.cmake create mode 100644 licenses/LICENSE.cutlass.txt create mode 100644 python/tvm/contrib/cutlass/__init__.py create mode 100644 python/tvm/contrib/cutlass/build.py create mode 100644 python/tvm/contrib/cutlass/gemm_operation.py create mode 100644 python/tvm/contrib/cutlass/gemm_profiler.py create mode 100644 python/tvm/contrib/cutlass/gen_gemm.py create mode 100644 python/tvm/contrib/cutlass/library.py create mode 100644 python/tvm/relay/op/contrib/cutlass.py create mode 100644 src/relay/backend/contrib/cutlass/codegen.cc create mode 100644 tests/python/contrib/test_cutlass.py diff --git a/.gitmodules b/.gitmodules index 6ef740e33153..8dfda44d10a0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "3rdparty/libbacktrace"] path = 3rdparty/libbacktrace url = https://github.com/tlc-pack/libbacktrace.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000000..a3bcc6981d5d --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d diff --git a/CMakeLists.txt b/CMakeLists.txt index 24f0653b3a78..f4e52e61be83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ tvm_option(USE_MKLDNN "Build with MKLDNN" OFF) tvm_option(USE_DNNL_CODEGEN "Enable MKLDNN (DNNL) codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) tvm_option(USE_CUBLAS "Build with cuBLAS" OFF) +tvm_option(USE_CUTLASS "Build with CUTLASS" OFF) tvm_option(USE_THRUST "Build with Thrust" OFF) tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) @@ -428,6 +429,7 @@ include(cmake/modules/contrib/EthosU.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) +include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/ExampleTargetHooks.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Posit.cmake) diff --git a/LICENSE b/LICENSE index 52b2219396d2..18718f986baa 100644 --- a/LICENSE +++ b/LICENSE @@ -238,3 +238,8 @@ The Unlicense ------------- 3rdparty/rang + +BSD 3-Clause "New" or "Revised" License +--------------------------------------- + +3rdparty/cutlass \ No newline at end of file diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake new file mode 100644 index 000000000000..79555e5e26de --- /dev/null +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_CUTLASS) + file(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) + list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) + + message(STATUS "Build with CUTLASS") +endif() diff --git a/licenses/LICENSE.cutlass.txt b/licenses/LICENSE.cutlass.txt new file mode 100644 index 000000000000..64a49d680b1e --- /dev/null +++ b/licenses/LICENSE.cutlass.txt @@ -0,0 +1,23 @@ +Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/tvm/contrib/cutlass/__init__.py b/python/tvm/contrib/cutlass/__init__.py new file mode 100644 index 000000000000..c95d39ec5d69 --- /dev/null +++ b/python/tvm/contrib/cutlass/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BYOC support for CUTLASS.""" +from .build import tune_cutlass_kernels, build_cutlass_kernels diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py new file mode 100644 index 000000000000..1e04d0ce525f --- /dev/null +++ b/python/tvm/contrib/cutlass/build.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Driver for partitioning and building a Relay module for CUTLASS offload.""" +import tvm +from tvm import runtime, relay +from .gen_gemm import CutlassGemmProfiler + + +class GemmAnnotator(tvm.relay.ExprVisitor): + """Annotates partitioned functions with shape and dtype information.""" + + def __init__(self): + super().__init__() + self.signature = {} + + def visit_call(self, call): + op = call.op + if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: + self.signature["op_type"] = op.attrs["Composite"] + for i, arg in enumerate(op.params): + self.signature["arg%d_shape" % i] = arg.checked_type.shape + self.signature["arg%d_dtype" % i] = arg.checked_type.dtype + self.signature["ret_shape"] = op.ret_type.shape + self.signature["ret_dtype"] = op.ret_type.dtype + + +def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): + """Given a module partitioned for CUTLASS offloading, profile each workload to select which + kernels to emit. + + Parameters + ---------- + mod : IRModule + The Relay module with cutlass partitions. + + sm : int + An integer specifying the compute capability. For example, 75 for Turing and + 80 or 86 for Ampere. + + profile_all : bool + Whether or not profile all candidate kernels, or stop profiling after + the first applicable kernel is found. + + use_multiprocessing : bool + Whether or not compile profiler executables for different kernels in parallel. + + tmp_dir : string, optional + A temporary directory where intermediate compiled artifacts will be stored. + + Returns + ------- + mod : IRModule + The updated module annotated with cutlass profiling information. + + num_cutlass_partition : int + The number of partitioned functions created for CUTLASS. + """ + cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir) + num_cutlass_partition = 0 + for var in mod.get_global_vars(): + fun_name = var.name_hint + func = mod[fun_name] + annotator = GemmAnnotator() + if "cutlass" in fun_name: + num_cutlass_partition += 1 + annotator.visit(func) + # call cutlass profiler to find best settings, update attr + new_attrs = {} + new_attrs.update(annotator.signature) + for key in func.attrs.keys(): + new_attrs[key] = func.attrs[key] + # call profiler + arg0_shape = new_attrs["arg0_shape"] + arg1_shape = new_attrs["arg1_shape"] + MM = arg0_shape[0] + KK = arg0_shape[1] + NN = arg1_shape[0] + out = cutlass_profiler.profile( + MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing + ) + if new_attrs["op_type"] == "cutlass.dense": + new_attrs["cutlass_op_def"] = out["opdef"] + elif new_attrs["op_type"] == "cutlass.dense_bias": + new_attrs["cutlass_op_def"] = out["opdef_bias"] + elif new_attrs["op_type"] == "cutlass.dense_bias_relu": + new_attrs["cutlass_op_def"] = out["opdef_bias_relu"] + elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]: + new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"] + else: + raise ValueError("%s pattern is not implemented." % new_attrs["op_type"]) + new_attrs["cutlass_op_name"] = out["name"] + + print("The best kernel is " + new_attrs["cutlass_op_name"]) + if new_attrs["cutlass_op_name"].find("_tn_align") > 0: + new_attrs["lda"] = "K" + new_attrs["ldb"] = "K" + new_attrs["ldc"] = "N" + elif new_attrs["cutlass_op_name"].find("_nt_align") > 0: + new_attrs["lda"] = "M" + new_attrs["ldb"] = "N" + new_attrs["ldc"] = "N" + else: + raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"]) + new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) + new_func = relay.Function( + func.params, + func.body, + ret_type=func.ret_type, + type_params=func.type_params, + attrs=new_attrs, + ) + mod.update_func(var, new_func) + + return mod, num_cutlass_partition + + +def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"): + """Compile CUTLASS kernels in lib and return the runtime module ready to run. + + Parameters + ---------- + lib : GraphExecutorFactoryModule + The output from relay.build containing compiled host code and non-cutlass kernels. + + sm : int + An integer specifying the compute capability. For example, 75 for Turing and + 80 or 86 for Ampere. + + tmp_dir : string, optional + A temporary directory where intermediate compiled artifacts will be stored. + + lib_path : string, optional + The path to a shared library which will be generated as the result of the build process + + Returns + ------- + updated_lib : runtime.Module + The updated module with compiled cutlass kernels. + """ + cutlass_path = "../../../3rdparty/cutlass/include" + cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include" + + kwargs = {} + kwargs["cc"] = "nvcc" + kwargs["options"] = [ + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm), + "-Xcompiler=-fPIC", + "-Xcompiler=-Wconversion", + "-Xcompiler=-fno-strict-aliasing", + "-O3", + "-std=c++14", + "-I" + cutlass_path, + "-I" + cutlass_util_path, + ] + lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) + return runtime.load_module(lib_path) diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py new file mode 100644 index 000000000000..e53b3ee7b93a --- /dev/null +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import +"""Generator for CUTLASS GEMM kernels.""" +from .library import * + + +class GemmOperation: + """Describes various attributes for instantiating GEMM kernels.""" + + def __init__( + self, + arch, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + ): + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + def accumulator_type(self): + return self.tile_description.math_instruction.element_accumulator + + def short_math_name(self): + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. """ + inst_shape = "" + intermediate_type = "" + + if ( + self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp + or self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp + ): + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if ( + self.tile_description.math_instruction.element_a != self.A.element + and self.tile_description.math_instruction.element_a + != self.tile_description.math_instruction.element_accumulator + ): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % ( + self.short_math_name(), + inst_shape, + intermediate_type, + "gemm", + ) + + def extended_name(self): + """ Append data types if they differ from compute type. """ + if ( + self.C.element != self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${element_c}_${core_name}_${element_a}" + elif ( + self.C.element == self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = substitute_template( + extended_name, + { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + def layout_name(self): + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, + and layout. + """ + threadblock = self.tile_description.procedural_name() + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + return substitute_template( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment, + }, + ) + + def leading_dim(self): + """ lda, ldb, ldc, according to the leading dimension. """ + if self.A.layout == LayoutType.RowMajor: + lda = "K" + elif self.A.layout == LayoutType.ColumnMajor: + lda = "M" + else: + ValueError("The layout of A is not implemented.") + + if self.B.layout == LayoutType.RowMajor: + ldb = "N" + elif self.B.layout == LayoutType.ColumnMajor: + ldb = "K" + else: + ValueError("The layout of B is not implemented.") + + if self.C.layout == LayoutType.RowMajor: + ldc = "N" + elif self.C.layout == LayoutType.ColumnMajor: + ldc = "M" + else: + ValueError("The layout of B is not implemented.") + + return substitute_template( + "int lda = ${lda_val};\n\tint ldb = ${ldb_val};\n\tint ldc = ${ldc_val};\n", + { + "lda_val": lda, + "ldb_val": ldb, + "ldc_val": ldc, + }, + ) + + +class EmitGemmInstance: + """ Responsible for emitting a CUTLASS template definition.""" + + def __init__(self): + self.epilogue_default = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >""" + self.epilogue_no_beta_scaling = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >""" + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::Gemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue}, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + def emit(self, operation, no_beta_scaling=False): + """Instantiate a GEMM kernel from given `operation`.""" + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + epilogue_vector_length = ( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + // DataTypeSize[operation.C.element] + ) + residual = "" + complex_transform_tag = "cutlass::ComplexTransform::kNone" + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": complex_transform_tag, + "transform_b": complex_transform_tag, + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "residual": residual, + } + + gemm_template = substitute_template( + self.gemm_template, + { + "epilogue": self.epilogue_no_beta_scaling + if no_beta_scaling + else self.epilogue_default + }, + ) + return substitute_template(gemm_template, values) diff --git a/python/tvm/contrib/cutlass/gemm_profiler.py b/python/tvm/contrib/cutlass/gemm_profiler.py new file mode 100644 index 000000000000..13679cd05c42 --- /dev/null +++ b/python/tvm/contrib/cutlass/gemm_profiler.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, invalid-name +"""Instantiate a C++ source for profiling CUTLASS kernels.""" + + +class GemmProfilerEmitter(object): + """Emit a C++ source for profiling CUTLASS kernels.""" + + def __init__(self): + from jinja2 import Template + + self.template = Template( + """ +#include +#include +#include +#include + +#include "cuda_runtime.h" +#include "cutlass/gemm/device/gemm.h" + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \\ + << std::endl; \\ + exit(EXIT_FAILURE); \\ + } \\ + } + +#define CUDA_CHECK(status) \\ + { \\ + cudaError_t error = status; \\ + if (error != cudaSuccess) { \\ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \\ + << " at line: " << __LINE__ << std::endl; \\ + exit(EXIT_FAILURE); \\ + } \\ + } + +template +cudaError_t CutlassGemmRCR( + int M, + int N, + int K, + DTypeC alpha, + DTypeA const *A, + int lda, + DTypeB const *B, + int ldb, + DTypeC beta, + DTypeC *C, + int ldc) { + using namespace std::chrono; + {{OperatorDef}} + Operation_{{OperatorName}} gemm_operator; + Operation_{{OperatorName}}::Arguments args({M, N, K}, + {A, lda}, + {B, ldb}, + {C, ldc}, + {C, ldc}, + {alpha, beta}); + cutlass::Status status = gemm_operator(args); + CUTLASS_CHECK(status) + + high_resolution_clock::time_point t1 = high_resolution_clock::now(); + for (int i = 0; i < 100; ++i) { + status = gemm_operator(args); + } + cudaDeviceSynchronize(); + high_resolution_clock::time_point t2 = high_resolution_clock::now(); + duration time_span = duration_cast>(t2 - t1); + std::cout << time_span.count() << std::endl; + return cudaSuccess; +} + + +template +cudaError_t AllocateMatrix(DType **matrix, int ldm, int rows, int columns, int seed = 0) { + cudaError_t result; + + size_t sizeof_matrix = sizeof(DType) * rows * columns; + + // Allocate device memory. + result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); + + if (result != cudaSuccess) { + std::cerr << "Failed to allocate matrix: " + << cudaGetErrorString(result) << std::endl; + return result; + } + + // Clear the allocation. + result = cudaMemset(*matrix, 0, sizeof_matrix); + + if (result != cudaSuccess) { + std::cerr << "Failed to clear matrix device memory: " + << cudaGetErrorString(result) << std::endl; + return result; + } + + if (result != cudaSuccess) { + std::cerr << "Failed to initialize matrix: " + << cudaGetErrorString(result) << std::endl; + return result; + } + + return result; +} + +template +cudaError_t TestCutlassGemm(int M, int N, int K, DTypeC alpha, DTypeC beta) { + cudaError_t result; + + {{LeadingDim}} + // size_t sizeof_C = sizeof(DTypeC) * ldc * N; + DTypeA *A; + DTypeB *B; + DTypeC *C_cutlass; + result = AllocateMatrix(&A, lda, M, K, 0); + if (result != cudaSuccess) { + return result; + } + result = AllocateMatrix(&B, ldb, K, N, 17); + if (result != cudaSuccess) { + cudaFree(A); + return result; + } + result = AllocateMatrix(&C_cutlass, ldc, M, N, 101); + if (result != cudaSuccess) { + cudaFree(A); + cudaFree(B); + return result; + } + result = CutlassGemmRCR(M, N, K, alpha, A, lda, B, ldb, + beta, C_cutlass, ldc); + if (result != cudaSuccess) { + std::cerr << "CUTLASS GEMM kernel failed: " + << cudaGetErrorString(result) << std::endl; + cudaFree(C_cutlass); + cudaFree(B); + cudaFree(A); + + return result; + } + cudaFree(C_cutlass); + cudaFree(B); + cudaFree(A); + return cudaSuccess; +} + +int main(int argc, const char *arg[]) { + int problem[3] = { 4096, 4096, 4096 }; + for (int i = 1; i < argc && i < 4; ++i) { + std::stringstream ss(arg[i]); + ss >> problem[i - 1]; + } + float scalars[2] = { 1, 0 }; + cudaError_t result = TestCutlassGemm< {{DTypeA}}, {{DTypeB}}, {{DTypeC}}>( + problem[0], // GEMM M dimension + problem[1], // GEMM N dimension + problem[2], // GEMM K dimension + static_cast<{{DTypeC}}>(scalars[0]), // alpha + static_cast<{{DTypeC}}>(scalars[1]) // beta + ); + return result == cudaSuccess ? 0 : -1; +} +""" + ) + + def emit(self, op_name, op_def, dtype_a, dtype_b, dtype_c, ld): + src = self.template.render( + OperatorName=op_name, + OperatorDef=op_def, + DTypeA=dtype_a, + DTypeB=dtype_b, + DTypeC=dtype_c, + LeadingDim=ld, + ) + return src diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py new file mode 100644 index 000000000000..803a90a1c54f --- /dev/null +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -0,0 +1,355 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Kernel generator and profiler for CUTLASS.""" +import os +import re +import tempfile +import subprocess +import multiprocessing +from .gemm_operation import GemmOperation, EmitGemmInstance +from .gemm_profiler import GemmProfilerEmitter +from .library import ( + EpilogueFunctor, + SwizzlingFunctor, + TensorDescription, + DataTypeTag, + LayoutType, + MathInstruction, + DataType, + OpcodeClass, + MathOperation, + TileDescription, +) + + +def create_gemm_operator( + layouts, + tile_descriptions, + data_type, + alignment_constraints, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, +): + """Exhaustively instantiate all kernels from a given configuration.""" + ret = [] + kernel_emitter = EmitGemmInstance() + profiler_emitter = GemmProfilerEmitter() + + element_a, element_b, element_c, element_epilogue = data_type + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + op_entry = {} + op = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + op_bias = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationBias, + swizzling_functor, + ) + op_bias_relu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationRelu, + swizzling_functor, + ) + op_bias_gelu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationGelu, + swizzling_functor, + ) + + kernel_emitter = EmitGemmInstance() + op_entry["op"] = op + op_entry["name"] = op.procedural_name() + op_entry["opdef"] = kernel_emitter.emit(op) + op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, no_beta_scaling=True) + op_entry["opdef_bias_relu"] = kernel_emitter.emit( + op_bias_relu, no_beta_scaling=True + ) + op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu) + op_entry["src"] = profiler_emitter.emit( + op.procedural_name(), + op_entry["opdef"], + DataTypeTag[element_a], + DataTypeTag[element_b], + DataTypeTag[element_c], + op.leading_dim(), + ) + op_entry["runtime"] = 9999999 + ret.append(op_entry) + return ret + + +def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile_descriptions): + """Common kernel generator to be used by archtecture specific generators.""" + ops = [] + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + for math_inst in math_instructions: + tile_descriptions = get_tile_descriptions(math_inst) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints) + + ops.extend(out) + + return ops + + +def generate_sm75_tensor_op_1688(out_dtype): + """Generate GEMM kernels for Turing.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + + alignment_constraints = [8, 4, 2, 1] + + def get_tile_descriptions(math_inst): + min_cc = 75 + max_cc = 1024 + return [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + ] + + return generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions + ) + + +def generate_sm80_tensor_op_16816(out_dtype): + """Generate GEMM kernels for Ampere.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + + alignment_constraints = [8, 4, 2] + + def get_tile_descriptions(math_inst): + min_cc = 80 + max_cc = 1024 + max_cc_smem_limited = 80 + return [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + return generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions + ) + + +GENERATOR_FUNC_TABLE = { + 75: generate_sm75_tensor_op_1688, + 80: generate_sm80_tensor_op_16816, +} + + +class ProfilerEngine(object): + """Compile and run a given profiler executable.""" + + def __init__(self, cuda_arch, cutlass_path, binary_prefix): + self.cuda_arch = cuda_arch + self.binary_prefix = binary_prefix + self.cutlass = cutlass_path + self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format( + cutlass=cutlass_path + ) + self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" + self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format( + arch=cuda_arch + ) + self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing" + self.cmd = "nvcc {cflags} {src} -o {output}" + + def _compile(self, op): + os.makedirs(self.binary_prefix, exist_ok=True) + opath = os.path.join(self.binary_prefix, op["name"]) + if os.path.exists(opath): + return + fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu") + fi.write(op["src"]) + fi.close() + cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath) + os.system(cmd) + os.unlink(fi.name) + + def compile_all(self, ops, use_multiprocessing=False): + """Compile all profiler executables.""" + if use_multiprocessing: + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + pool.map(self._compile, ops) + else: + for op in ops: + self._compile(op) + + def evaluate(self, op_name, args): + """Run the profiler executable corresponding to op_name with args.""" + opath = os.path.join(self.binary_prefix, op_name) + cmd = [opath] + if args is not None: + cmd.append(str(args[0])) + cmd.append(str(args[1])) + cmd.append(str(args[2])) + if len(args) > 3: + cmd.append(str(args[3])) + try: + sp = subprocess.run(cmd, capture_output=True, check=True) + rt = float(sp.stdout) + print(op_name, rt) + except subprocess.CalledProcessError: + rt = -1 + return rt + + +class CutlassGemmProfiler(object): + """Profile all candidate kernels and select the best one.""" + + def __init__(self, sm, cutlass_path, binary_path): + assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm + self.engine = ProfilerEngine(sm, cutlass_path, binary_path) + self.sm = sm + + def check_align(self, op_name, M): + """Filter out kernels that cannot be supported.""" + aligns = re.findall(r"align[1|2|4|8]", op_name) + assert len(aligns) == 1 + align = int(aligns[0][-1]) + if M % align != 0: + return False + return True + + def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False): + """Profile and select the best kernel from candidate kernels. + If profile_all is False, return immediately after the first applicable kernel is found. + If use_multiprocessing is True, compile all profiler executables in parallel. + """ + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype) + ops = list(filter(lambda op: self.check_align(op["name"], M), ops)) + + for op in ops: + op["runtime"] = -1 + + self.engine.compile_all(ops, use_multiprocessing) + + for op in ops: + out = self.engine.evaluate(op["name"], [M, N, K]) + op["runtime"] = out + if out > 0 and profile_all is False: + break + + valid_ops = filter(lambda op: op["runtime"] > 0, ops) + output = sorted(valid_ops, key=lambda i: i["runtime"]) + return output[0] diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py new file mode 100644 index 000000000000..7d544293901a --- /dev/null +++ b/python/tvm/contrib/cutlass/library.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Various type definitions to help instantiate CUTLASS kernels.""" +import re +import enum +from enum import auto as enum_auto + + +class GeneratorTarget(enum.Enum): + Library = enum_auto() + + +class DataType(enum.Enum): + f16 = enum_auto() + f32 = enum_auto() + + +ShortDataTypeNames = { + DataType.f16: "h", + DataType.f32: "s", +} + + +DataTypeNames = { + DataType.f16: "f16", + DataType.f32: "f32", +} + +DataTypeTag = { + DataType.f16: "cutlass::half_t", + DataType.f32: "float", +} + +DataTypeSize = { + DataType.f16: 16, + DataType.f32: 32, +} + + +class MathOperation(enum.Enum): + multiply_add = enum_auto() + + +MathOperationTag = { + MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", +} + + +class LayoutType(enum.Enum): + ColumnMajor = enum_auto() + RowMajor = enum_auto() + + +LayoutTag = { + LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", + LayoutType.RowMajor: "cutlass::layout::RowMajor", +} + + +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, +} + + +ShortLayoutTypeNames = { + LayoutType.ColumnMajor: "n", + LayoutType.RowMajor: "t", +} + + +class OpcodeClass(enum.Enum): + Simt = enum_auto() + TensorOp = enum_auto() + WmmaTensorOp = enum_auto() + + +OpcodeClassNames = { + OpcodeClass.Simt: "simt", + OpcodeClass.TensorOp: "tensorop", + OpcodeClass.WmmaTensorOp: "wmma_tensorop", +} + +OpcodeClassTag = { + OpcodeClass.Simt: "cutlass::arch::OpClassSimt", + OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp", + OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp", +} + + +class OperationKind(enum.Enum): + Gemm = enum_auto() + + +OperationKindNames = { + OperationKind.Gemm: "gemm", +} + + +class Target(enum.Enum): + library = enum_auto() + + +def substitute_template(template, values): + """Instantiate a kernel template using `values`.""" + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + + +class GemmKind(enum.Enum): + Gemm = enum_auto() + + +GemmKindNames = { + GemmKind.Gemm: "gemm", +} + + +class EpilogueFunctor(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationRelu = enum_auto() + LinearCombinationBias = enum_auto() + LinearCombinationGelu = enum_auto() + + +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination", + EpilogueFunctor.LinearCombinationRelu: "cutlass::epilogue::thread::LinearCombinationRelu", + EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination", + EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU", +} + + +class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + + +SwizzlingFunctorTag = { + SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", + SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", + SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", +} + + +class MathInstruction: + """Describe characteristics of a math instruction.""" + + def __init__( + self, + instruction_shape, + element_a, + element_b, + element_accumulator, + opcode_class, + math_operation=MathOperation.multiply_add, + ): + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + + +class TileDescription: + """Describe characteristics of a GEMM tile.""" + + def __init__( + self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute + ): + self.threadblock_shape = threadblock_shape + self.stages = stages + self.warp_count = warp_count + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + return "%dx%d_%dx%d" % ( + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.stages, + ) + + +class TensorDescription: + def __init__(self, element, layout, alignment=1): + self.element = element + self.layout = layout + self.alignment = alignment diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 30c2db0ddf0b..1dd6da6c2747 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -24,3 +24,4 @@ from .coreml import * from .ethosn import * from .tensorrt import * +from .cutlass import * diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py new file mode 100644 index 000000000000..631089ce766d --- /dev/null +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Patterns supported CUTLASS.""" +from tvm.relay import transform +from ...dataflow_pattern import wildcard, is_op, is_constant + + +def make_gelu_pattern(bias_out, out_dtype="float16"): + mul = is_op("multiply")(bias_out, is_constant()) + if out_dtype == "float16": + erf = is_op("cast")(is_op("erf")(is_op("cast")(mul))) + else: + erf = is_op("erf")(mul) + mul_half = is_op("multiply")(erf, is_constant()) + add = is_op("add")(mul_half, is_constant()) + return is_op("multiply")(add, bias_out) + + +def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"): + """Create a pattern for dense op followed by activations.""" + data = wildcard() + weight = wildcard() + bias = wildcard() + gemm = is_op("nn.dense")(data, weight) + if with_bias: + add_or_bias_add = is_op("add") | is_op("nn.bias_add") + gemm_out = add_or_bias_add(gemm, bias) + else: + gemm_out = gemm + + if with_act is None: + return gemm_out + if isinstance(with_act, str) and with_act == "relu": + return is_op("nn.relu")(gemm_out) + + assert isinstance(with_act, str) and with_act == "gelu" + return make_gelu_pattern(gemm_out, out_dtype) + + +def partition_for_cutlass(mod): + """Partition the input module into CUTLASS-supported subgraphs.""" + dense_pat = ("cutlass.dense", make_gemm_pattern(False, None)) + dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None)) + dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu")) + dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu")) + dense_bias_gelu_fp32_pat = ( + "cutlass.dense_bias_gelu_fp32", + make_gemm_pattern(True, "gelu", out_dtype="float32"), + ) + cutlass_patterns = [ + dense_bias_gelu_fp16_pat, + dense_bias_gelu_fp32_pat, + dense_bias_relu_pat, + dense_bias_pat, + dense_pat, + ] + mod = transform.MergeComposite(cutlass_patterns)(mod) + mod = transform.AnnotateTarget(["cutlass"])(mod) + mod = transform.PartitionGraph()(mod) + return mod diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 0d575b3ec498..428cdd3d431a 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -44,6 +44,12 @@ struct Output { bool need_copy; }; +struct GenerateBodyOutput { + std::string decl; + std::vector buffers; + std::vector outputs; +}; + class CSourceModuleCodegenBase { public: CSourceModuleCodegenBase() = default; diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc new file mode 100644 index 000000000000..913322ca06da --- /dev/null +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/cutlass/codegen.cc + * \brief Implementation of CUTLASS codegen. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" +#include "../codegen_c/codegen_c.h" + +namespace tvm { +namespace relay { +namespace contrib { + +using namespace backend; +using Str2StrMap = std::unordered_map; + +static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, {"float32", "float"}}; + +Str2StrMap DenseArgs(const Map& attrs) { + Str2StrMap args; + auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); + auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); + auto ret_dtype = std::string(attrs["ret_dtype"].as()->data); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["ElementInputA"] = dtype_map.at(arg0_dtype); + args["ElementInputB"] = dtype_map.at(arg1_dtype); + args["ElementOutput"] = dtype_map.at(ret_dtype); + args["M"] = std::to_string(arg0_shape->at(0).as()->value); + args["K"] = std::to_string(arg0_shape->at(1).as()->value); + args["N"] = std::to_string(arg1_shape->at(0).as()->value); + args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); + args["op_name"] = std::string(attrs["cutlass_op_name"].as()->data); + args["op_type"] = std::string(attrs["op_type"].as()->data); + args["lda"] = std::string(attrs["lda"].as()->data); + args["ldb"] = std::string(attrs["ldb"].as()->data); + args["ldc"] = std::string(attrs["ldc"].as()->data); + return args; +} + +inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { + for (int i = 0; i < indent; ++i) { + os << " "; + } + os << stmt; +} + +std::string DenseOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + bool has_bias = false; + bool is_gelu = + attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 + if (attrs.at("op_type") == "cutlass.dense_bias" || + attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) { + has_bias = true; + } + std::ostringstream gemm_decl; + CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); + CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); + CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, attrs.at("op_def")); + CutlassPrint(gemm_decl, "using Gemm = Operation_" + attrs.at("op_name") + ";\n"); + /// Gemm Call + + // Create TensorRef + CutlassPrint(gemm_decl, "int M = " + attrs.at("M") + ";\n"); + CutlassPrint(gemm_decl, "int N = " + attrs.at("N") + ";\n"); + CutlassPrint(gemm_decl, "int K = " + attrs.at("K") + ";\n"); + CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n"); + // Initialize alpha for dot product computation + CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); + if (is_gelu) { + // GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); + } else { + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); + } + + // Split K dimension into 1 partitions + CutlassPrint(gemm_decl, "int split_k_slices = 1;\n"); + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + ICHECK(func_args.size() >= 2); + CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + ");\n"); + CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + ");\n"); + if (has_bias) { + ICHECK(func_args.size() >= 3); + CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + ");\n"); + } + CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n"); + + CutlassPrint(gemm_decl, "typename Gemm::Arguments arguments{\n"); + CutlassPrint(gemm_decl, " problem_size,\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + if (has_bias) { + CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); + } else { + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + } + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + if (has_bias && !is_gelu) { + CutlassPrint(gemm_decl, " {alpha},\n"); + } else { + // For GeLU, we explicitly specify the scale. + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + } + CutlassPrint(gemm_decl, " split_k_slices};\n"); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + CutlassPrint(gemm_decl, "size_t workspace_size = Gemm::get_workspace_size(arguments);\n"); + // Allocate workspace memory + CutlassPrint(gemm_decl, + "cutlass::device_memory::allocation workspace(workspace_size);\n"); + // Instantiate CUTLASS kernel depending on template + CutlassPrint(gemm_decl, "Gemm gemm_op;\n"); + // Check the problem size is supported or not + CutlassPrint(gemm_decl, "cutlass::Status status = gemm_op.can_implement(arguments);\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Initialize CUTLASS kernel with arguments and workspace pointer + CutlassPrint(gemm_decl, "status = gemm_op.initialize(arguments, workspace.get());\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Launch initialized CUTLASS kernel + CutlassPrint(gemm_decl, "status = gemm_op();\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + return gemm_decl.str(); +} + +class CodegenCutlass : public MemoizedExprTranslator>, public CodegenCBase { + public: + CodegenCutlass(const std::string& id, const Map& attrs) { + this->ext_func_id_ = id; + this->attrs_ = attrs; + } + + std::vector VisitExprDefault_(const Object* op) final { + LOG(FATAL) << "Cutlass codegen doesn't support: " << op->GetTypeKey(); + return {}; + } + + std::vector VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(GetRef(node)); + Output output; + output.name = node->name_hint(); + return {output}; + } + + std::vector VisitExpr_(const CallNode* call) final { + const auto* func = call->op.as(); + ICHECK(func) << "Only composite function is supported for CUTLASS."; + GenerateBodyOutput ret = GenerateCompositeFunctionCall(func, call); + ext_func_body_.push_back(ret.decl); + return ret.outputs; + } + + std::string JIT(const std::vector& out) { + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); + } + + private: + std::vector GetArgumentNames(const CallNode* call) { + std::vector arg_names; + for (size_t i = 0; i < call->args.size(); ++i) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { + arg_names.push_back(out.name); + } + } + return arg_names; + } + + GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee, + const CallNode* caller) { + const auto pattern_name = callee->GetAttr(attr::kComposite); + ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; + + if (pattern_name == "cutlass.dense") { + const auto* dense_call = GetRootCall(callee->body.as(), 0, {"nn.dense"}); + return GenerateBody(dense_call, "cutlass_dense", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->op.as()->name; + const auto* dense_call = + GetRootCall(callee->body.as(), 1, {"nn.dense", add_or_bias_add}); + return GenerateBody(dense_call, "cutlass_dense_bias", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias_relu") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; + const auto* dense_call = + GetRootCall(callee->body.as(), 2, {"nn.dense", add_or_bias_add, "nn.relu"}); + return GenerateBody(dense_call, "cutlass_dense_bias_relu", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias_gelu_fp16") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->args[1].as()->op.as()->name; + const auto* dense_call = GetRootCall(callee->body.as(), 8, + {"nn.dense", add_or_bias_add, "multiply", "cast", "erf", + "cast", "multiply", "add", "multiply"}); + return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias_gelu_fp32") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->args[1].as()->op.as()->name; + const auto* dense_call = GetRootCall( + callee->body.as(), 6, + {"nn.dense", add_or_bias_add, "multiply", "erf", "multiply", "add", "multiply"}); + return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } + LOG(FATAL) << "Unknown composite function: " << pattern_name; + return {}; + } + + GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name, + const std::vector& func_args, + const Str2StrMap& attribute_args) { + // Make function call with input buffers when visiting arguements + ICHECK_GT(func_args.size(), 0); + std::ostringstream decl_stream; + decl_stream << "(" << func_args[0]; + for (size_t i = 1; i < func_args.size(); ++i) { + decl_stream << ", " << func_args[i]; + } + // Analyze the output buffers + std::vector out_types; + if (root_call->checked_type()->IsInstance()) { + auto type_node = root_call->checked_type().as(); + for (auto field : type_node->fields) { + ICHECK(field->IsInstance()); + out_types.push_back(field); + } + } else if (root_call->checked_type()->IsInstance()) { + ICHECK(root_call->checked_type()->IsInstance()); + out_types.push_back(root_call->checked_type()); + } else { + LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false); + } + GenerateBodyOutput ret; + for (const auto& out_type : out_types) { + const std::string out = "out" + std::to_string(buf_idx_++); + decl_stream << ", " << out; + Output output; + output.name = out; + output.dtype = GetDtypeString(out_type.as()); + output.need_copy = false; + ret.outputs.push_back(output); + } + decl_stream << ");"; + if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" || + func_name == "cutlass_dense_bias_relu" || func_name == "cutlass_dense_bias_gelu") { + ret.decl = DenseOp(ext_func_id_, attribute_args, func_args); + } + return ret; + } + /*! \brief The id of the external cutlass ext_func. */ + std::string ext_func_id_{""}; + /*! \brief The attrs of the external cutlass ext_func. */ + Map attrs_; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ + Array ext_func_args_; + /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ + std::vector ext_func_body_; + /*! \brief The array declared to store the constant values. */ + std::string const_array_name_; + /*! \brief The declaration of intermediate buffers. */ + std::vector buf_decl_; +}; // class CodegenCutlass + +class CutlassModuleCodegen : public CSourceModuleCodegenBase { + public: + std::pair> GenCutlassFunc(const Function& func) { + ICHECK(func.defined()) << "Input error: expect a Relay function."; + // Record the external symbol for runtime lookup. + auto sid = GetExtSymbol(func); + const auto* attrs = func->attrs.as(); + ICHECK(attrs != nullptr); + const auto dict = attrs->dict; + CodegenCutlass builder(sid, dict); + auto out = builder.VisitExpr(func->body); + code_stream_ << builder.JIT(out); + return {sid, {}}; + } + + runtime::Module CreateCSourceModule(const ObjectRef& ref) override { + // create header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + // cutlass header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + + ICHECK(ref->IsInstance()); + auto res = GenCutlassFunc(Downcast(ref)); + std::string code = code_stream_.str(); + String sym = std::get<0>(res); + Array variables = std::get<1>(res); + // Create a CSource module + const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); + ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; + return (*pf)(code, "cu", Array{sym}, variables); + } + + private: + /*! \brief The code stream that will be compiled by NVCC */ + std::ostringstream code_stream_; +}; // CutlassModuleCodegen + +/*! + * \brief The external cutlass compiler/codegen tool. It takes a Relay + * expression/module and compile it into a runtime module. + */ +runtime::Module CutlassCompiler(const ObjectRef& ref) { + CutlassModuleCodegen cutlass; + return cutlass.CreateCSourceModule(ref); +} + +TVM_REGISTER_GLOBAL("relay.ext.cutlass").set_body_typed(CutlassCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index ae58c2f08e8c..fa1dbc66d8a7 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -231,12 +231,6 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C } private: - struct GenerateBodyOutput { - std::string decl; - std::vector buffers; - std::vector outputs; - }; - std::vector GetArgumentNames(const CallNode* call) { std::vector arg_names; for (size_t i = 0; i < call->args.size(); ++i) { diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py new file mode 100644 index 000000000000..a3ddb06345df --- /dev/null +++ b/tests/python/contrib/test_cutlass.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import math +import pytest +import tvm +from tvm import relay +import numpy as np +from tvm.relay.op.contrib.cutlass import partition_for_cutlass +from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels + + +def get_ref_rt_mod(mod, params): + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target="cuda", params=params) + dev = tvm.device("cuda", 0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + return rt_mod, dev + + +def get_output(rt_mod, x): + rt_mod.set_input("data", x) + rt_mod.run() + return rt_mod.get_output(0).asnumpy() + + +def get_dense(M, N, K, out_dtype="float16"): + data = relay.var("data", shape=(M, K), dtype="float16") + weight = relay.var("weight", shape=(N, K), dtype="float16") + return relay.nn.dense(data, weight, out_dtype=out_dtype) + + +def get_dense_bias(M, N, K, out_dtype="float16"): + dense = get_dense(M, N, K, out_dtype=out_dtype) + bias = relay.var("bias", shape=(N,), dtype=out_dtype) + return relay.nn.bias_add(dense, bias) + + +def get_dense_bias_relu(M, N, K, out_dtype="float16"): + return relay.nn.relu(get_dense_bias(M, N, K, out_dtype="float16")) + + +def get_dense_bias_gelu(M, N, K, out_dtype="float16"): + bias_add = get_dense_bias(M, N, K, out_dtype) + mul = bias_add * relay.const((1.0 / math.sqrt(2.0)), dtype=out_dtype) + if out_dtype == "float16": + erf = relay.cast(relay.op.erf(relay.cast(mul, "float32")), "float16") + else: + erf = relay.op.erf(mul) + mul_half = erf * relay.const(0.5, dtype=out_dtype) + add = mul_half + relay.const(0.5, dtype=out_dtype) + return add * bias_add + + +def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"): + mod = partition_for_cutlass(mod) + mod, num_cutlass_partition = tune_cutlass_kernels( + mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir + ) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target="cuda", params=params) + lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path) + dev = tvm.device("cuda", 0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + return rt_mod, dev, num_cutlass_partition + + +def verify(func, M, N, K, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False): + if not tvm.get_global_func("relay.ext.cutlass", True): + return + mod = tvm.IRModule.from_expr(func) + typ = relay.transform.InferType()(mod) + out_dtype = typ["main"].body.checked_type.dtype + np_data = np.random.uniform(-1, 1, (M, K)).astype("float16") + np_weight = np.random.uniform(-1, 1, (N, K)).astype("float16") + np_bias = np.random.uniform(-1, 1, (N,)).astype(out_dtype) + + params = {"weight": np_weight, "bias": np_bias} + + rt_mod_ref, dev = get_ref_rt_mod(mod, params) + rt_mod, dev, num_partition = profile_and_build(mod, params, sm) + assert num_partition > 0 + + x = tvm.nd.array(np_data, device=dev) + + out = get_output(rt_mod, x) + ref_out = get_output(rt_mod_ref, x) + + np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + + if run_benchmark: + print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) + print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) + + +M = 1820 +N = 768 +K = 768 + + +def test_dense(): + verify(get_dense(M, N, K), M, N, K) + verify(get_dense(M, N, K, out_dtype="float32"), M, N, K) + + +def test_dense_bias(): + verify(get_dense_bias(M, N, K), M, N, K) + verify(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K) + + +def test_dense_bias_relu(): + verify(get_dense_bias_relu(M, N, K), M, N, K) + verify(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K) + + +def test_dense_bias_gelu(): + verify(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3) + verify(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__])