
Commit bd0f318

masahi authored and vinx13 committed
[BYOC] Add CUTLASS backend (apache#380)
* Add CUTLASS backend
* fix
* Wrap and annotate in FuseOpsByPattern optionally
* fix
* black
* add test for FuseOpsByPattern change
* black
* ignore 3rd party in pylint
* fix test
* another unused var warning
* Update include/tvm/relax/transform.h

Co-authored-by: Wuwei Lin <vincentl13x@gmail.com>

* fix for v3
* fix for int8 test in relay byoc
* more fix for cutlass update
* fix residual block fusion offload
* fix test

---------

Co-authored-by: Wuwei Lin <vincentl13x@gmail.com>
1 parent cdc73ec commit bd0f318

File tree

17 files changed (+1021, -199 lines)


cmake/modules/contrib/CUTLASS.cmake

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@
 # under the License.

 if(USE_CUDA AND USE_CUTLASS)
-  tvm_file_glob(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc)
-  list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC})
+  tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
+  list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})

   message(STATUS "Build with CUTLASS")
 endif()
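With USE_CUDA and USE_CUTLASS both ON, the glob now also picks up the new Relax sources under src/relax/backend/contrib/cutlass/. A quick sanity check of the resulting build, a minimal sketch using the existing has_cutlass() helper from tvm.contrib.cutlass:

# Minimal sketch: verify that this TVM build was configured with CUTLASS support.
from tvm.contrib.cutlass import has_cutlass

assert has_cutlass(), "rebuild TVM with USE_CUDA=ON and USE_CUTLASS=ON"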

gallery/how_to/work_with_relay/using_pipeline_executor.py

Lines changed: 1 addition & 7 deletions
@@ -29,12 +29,7 @@
 from tvm import relay
 from tvm.relay import testing
 import tvm.testing
-from tvm.contrib.cutlass import (
-    has_cutlass,
-    num_cutlass_partitions,
-    finalize_modules,
-    finalize_modules_vm,
-)
+from tvm.contrib.cutlass import finalize_modules

 img_size = 8
 #######################################################################
@@ -50,7 +45,6 @@ def get_network():
         "dweight", relay.TensorType((batch_size, 16 * img_size * img_size), "float16")
     )
     weight = relay.var("weight")
-    second_weight = relay.var("second_weight")
     bn_gamma = relay.var("bn_gamma")
     bn_beta = relay.var("bn_beta")
     bn_mmean = relay.var("bn_mean")

include/tvm/relax/transform.h

Lines changed: 7 additions & 6 deletions
@@ -183,21 +183,22 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  * of a fused function after successful matching.
  * \param patterns The patterns to detect. The order of the patterns determines the order
  * of priority in which they are matched. Higher-priority patterns should come earlier in the list.
+ * \param annotate_codegen If true, wrap each created composite function with another function,
+ * whose body consists only of a call to the composite function, and annotate the outer function
+ * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set as the prefix of the
+ * corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu".
+ * This must be True if the created composite functions are intended to be offloaded to
+ * an external backend without using the MergeCompositeFunctions pass.
  * \return The Pass.
  */
 TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
-                              const tvm::Array<DFPattern>& patterns);
+                              const tvm::Array<DFPattern>& patterns, bool annotate_codegen = false);

 /*!
  * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
  * function. The new function will be annotated with kCodegen and GlobalSymbol attributes,
  * and it is intented to be offloaded to an external backend.
  *
- * Even if there is only one composite function, or a backend does not benefit from receiving
- * larger subgraphs, this pass is required to run for offloading (BYOC) since a composite function
- * needs to be wrapped by an outer function that are annotated with "Codegen" and "global_symbol"
- * attributes.
- *
  * \return The Pass.
  */
 TVM_DLL Pass MergeCompositeFunctions();
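A rough illustration of what the new annotate_codegen parameter produces, exercised through the Python binding updated later in this commit. The pattern name "dnnl.conv2d_relu" and the variables mod and pattern are hypothetical placeholders, not code from the commit:

# Hypothetical sketch: with annotate_codegen=True, every composite function is
# wrapped by an outer function carrying the attributes BYOC offload requires.
from tvm import relax

mod = relax.transform.FuseOpsByPattern(
    [("dnnl.conv2d_relu", pattern)], annotate_codegen=True
)(mod)

for gvar, func in mod.functions.items():
    if "Codegen" in func.attrs:
        # Codegen holds the pattern-name prefix ("dnnl"); the wrapper's body is
        # a single call into the inner Composite="dnnl.conv2d_relu" function.
        print(gvar, func.attrs["Codegen"], func.attrs["global_symbol"])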

python/tvm/contrib/cutlass/build.py

Lines changed: 191 additions & 14 deletions
@@ -14,13 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, dangerous-default-value
+# pylint: disable=invalid-name, dangerous-default-value, arguments-differ
 """Driver for partitioning and building a Relay module for CUTLASS offload."""
 import logging
 import os
 import multiprocessing
 import tvm
-from tvm import runtime, relay
+from tvm import runtime, relay, relax
 from tvm.contrib.nvcc import get_cuda_version
 from tvm._ffi.registry import register_func
 from .gen_gemm import CutlassGemmProfiler
@@ -516,6 +516,167 @@ def tune_cutlass_function(
     )


+def _extract_relax_function_info(f):
+    signature = {}
+
+    for i, arg in enumerate(f.params):
+        sinfo = arg.struct_info
+        signature["arg%d_shape" % i] = list(sinfo.shape)
+        signature["arg%d_dtype" % i] = sinfo.dtype
+
+    ret_sinfo = f.ret_struct_info
+    signature["ret_shape"] = list(ret_sinfo.shape)
+    signature["ret_dtype"] = ret_sinfo.dtype
+
+    op_attrs = {}
+
+    def fvisit(e):
+        nonlocal op_attrs
+        if isinstance(e, relax.Call) and str(e.op) in ["relax.nn.conv2d"]:
+            op_attrs = e.attrs
+
+    relax.analysis.post_order_visit(f.body, fvisit)
+
+    return signature, op_attrs
+
+
+@relax.expr_functor.mutator
+class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
+    """A Relax function mutator that tunes and annotates CUTLASS composite functions
+    with shape, dtype and generated templates.
+    """
+
+    def __init__(self, mod, conv2d_profiler, options):
+        super().__init__(mod)
+        self.options = options
+        self.conv2d_profiler = conv2d_profiler
+
+    def handle_conv2d(self, f, op_type):
+        """Tune and annotate a conv2d op."""
+        signature, op_attrs = _extract_relax_function_info(f)
+
+        d_shape = signature["arg0_shape"]
+        w_shape = signature["arg1_shape"]
+        out_shape = signature["ret_shape"]
+        data_dtype = signature["arg0_dtype"]
+        weight_dtype = signature["arg1_dtype"]
+        out_dtype = signature["ret_dtype"]
+        padding = op_attrs["padding"]
+        strides = op_attrs["strides"]
+        dilation = op_attrs["dilation"]
+        conv_kind = ConvKind.Fprop
+
+        use_3xtf32 = self.options.get("use_3xtf32", False)
+        profile_all_alignments = self.options.get("profile_all_alignments", False)
+        find_first_valid = self.options.get("find_first_valid", True)
+        use_multiprocessing = self.options.get("use_multiprocessing", True)
+        split_k_slices = self.options.get("split_k_slices", [1])
+
+        op_name, op_def, _ = self.conv2d_profiler.profile(
+            op_type,
+            d_shape,
+            w_shape,
+            padding,
+            strides,
+            dilation,
+            out_dtype,
+            data_dtype,
+            weight_dtype,
+            use_3xtf32,
+            conv_kind,
+            split_k_slices,
+            profile_all_alignments,
+            find_first_valid=find_first_valid,
+            use_multiprocessing=use_multiprocessing,
+        )
+
+        return f.with_attrs(
+            {
+                "op_type": op_type,
+                "arg0_dtype": data_dtype,
+                "arg1_dtype": weight_dtype,
+                "ret_dtype": out_dtype,
+                "arg0_shape": d_shape,
+                "arg1_shape": w_shape,
+                "ret_shape": out_shape,
+                "strides": strides,
+                "padding": padding,
+                "dilation": dilation,
+                "cutlass_op_name": op_name,
+                "cutlass_op_def": op_def,
+            }
+        )
+
+    def visit_function_(self, f):
+        if "Composite" not in f.attrs:
+            body = super().visit_expr(f.body)
+            return relax.Function(f.params, body, f.ret_struct_info, f.attrs, f.span)
+
+        op_type = f.attrs["Composite"]
+
+        if "conv2d" in op_type:
+            return self.handle_conv2d(f, op_type)
+
+        raise ValueError("Unsupported composite {}".format(op_type))
+
+    def visit_span(self, span):
+        return span
+
+
+@register_func("contrib.cutlass.tune_relax_function")
+def profile_relax_function(functions, options):
+    """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates."""
+    tmp_dir = options.get("tmp_dir", "./tmp")
+    sm = options.get("sm", 80)
+    conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir)
+
+    annotated_functions = []
+
+    for f in functions:
+        annotator = CutlassRelaxFunctionAnnotator(
+            tvm.IRModule.from_expr(f), conv2d_profiler, options
+        )
+        annotated_functions.append(annotator.visit_expr(f))
+
+    return annotated_functions
+
+
+@register_func("contrib.cutlass.compile")
+def compile_cutlass_module(c_source_module, options):
+    """Compile all CUTLASS kernels in the given C-source module.
+
+    Parameters
+    ----------
+    c_source_module: runtime.Module
+        A C-source module containing CUTLASS kernels.
+
+    options: dict
+        Compilation options. Currently recognizes
+          "sm": The target architecture (compute capability), for example 75 or 80 (default: 80)
+          "threads": The number of threads to use in NVCC parallel compilation (default:
+                     use all logical cores)
+          "use_fast_math": Whether or not to use faster but approximate arithmetic in some
+                           CUTLASS epilogues (default: False)
+
+    Returns
+    -------
+    rt_mod : runtime.Module
+        A runtime module where all cutlass kernels have been compiled.
+    """
+    tmp_dir = options.get("tmp_dir", "./tmp")
+    defaults = {"sm": 80, "threads": -1, "use_fast_math": False}
+    compile_config = {key: options.get(key, val) for key, val in defaults.items()}
+
+    function_names = c_source_module.get_function("get_func_names")()
+    compile_options = _get_cutlass_compile_options(**compile_config)
+    lib_path = os.path.join(tmp_dir, "cutlass.o")
+    logger.info("Compiling generated CUTLASS code")
+    c_source_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)
+
+    # Recover static library
+    return tvm.runtime.load_static_library(lib_path, function_names)
+
+
 @register_func("relay.ext.cutlass.compile_for_cutlass")
 def compile_for_cutlass(mod, cutlass_target):
     """Given an IRModule with at least one Compiler='cutlass' Relay function, return a
@@ -549,6 +710,7 @@ def compile_for_cutlass(mod, cutlass_target):
         key: cutlass_target.attrs.get(key) for key in ["sm", "threads", "use_fast_math"]
     }
     tmp_dir = cutlass_target.attrs.get("tmp_dir")
+    compile_config["tmp_dir"] = tmp_dir

     # Tune
     logger.info("Tuning for CUTLASS")
@@ -558,18 +720,7 @@ def compile_for_cutlass(mod, cutlass_target):
     logger.info("Creating CSource module for CUTLASS")
     create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
     c_module = create_c_source_module(mod)
-    function_names = c_module.get_function("get_func_names")()
-    compile_options = _get_cutlass_compile_options(**compile_config)
-    lib_path = os.path.join(tmp_dir, "cutlass.o")
-    logger.info("Compiling generated CUTLASS code")
-    c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)
-
-    # Recover static library
-    logger.info("Loading compiled CUTLASS code")
-    final_mod = tvm.runtime.load_static_library(lib_path, function_names)
-
-    logger.info("Done with CUTLASS compilation")
-    return final_mod
+    return compile_cutlass_module(c_module, compile_config)


 def finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp"):
@@ -633,3 +784,29 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro",
         fo.write(code)
     lib = tvm.runtime.load_module(lib_path)
     return tvm.runtime.vm.Executable.load_exec(code, lib)
+
+
+def finalize_modules_relax(vm_exec, lib_path="compile.so", tmp_dir="./tmp"):
+    """finalize_modules_vm equivalent for Relax VM.
+
+    Parameters
+    ----------
+    vm_exec : vm.Executable
+        The output from relax.vm.build containing compiled host code and kernels.
+
+    lib_path : string
+        The path to a shared library which will be generated as the result of the build process.
+
+    tmp_dir : string
+        A temporary directory where intermediate compiled artifacts will be stored.
+
+    Returns
+    -------
+    updated_vm_exec : relax.vm.Executable
+        The updated VM executable with all compilation and linking completed.
+    """
+    lib_path = os.path.join(tmp_dir, lib_path)
+    vm_exec.mod.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
+    lib = tvm.runtime.load_module(lib_path)
+
+    return relax.vm.Executable(lib)
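Taken together, the additions above suggest the following end-to-end Relax flow. This is a hedged sketch rather than code from the commit: mod and patterns are assumed to exist, and relax.vm.build is the builder named in the finalize_modules_relax docstring:

# Sketch of the Relax CUTLASS offload path (assumed inputs: `mod`, `patterns`).
import tvm
from tvm import relax
from tvm.contrib.cutlass import finalize_modules_relax

# Group CUTLASS-supported ops into annotated composite functions.
mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)

# Building for CUDA is expected to trigger the hooks registered above
# ("contrib.cutlass.tune_relax_function" and "contrib.cutlass.compile").
exe = relax.vm.build(mod, target="cuda")

# Link the generated kernels into a shared library and reload the executable.
exe = finalize_modules_relax(exe)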

python/tvm/contrib/cutlass/conv2d_profiler.py

Lines changed: 3 additions & 3 deletions
@@ -35,15 +35,15 @@ def __init__(self):
       cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
       {
         reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
-        ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+        ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
       },
       {
         tensor_d.device_data(),
-        ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+        ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
       },
       {
         tensor_c.device_data(),
-        ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+        ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
       },
       {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
     );

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ def __init__(self, cuda_arch, cutlass_path, binary_prefix):
         self.cuda_arch = cuda_arch
         self.binary_prefix = binary_prefix
         self.cutlass = cutlass_path
-        self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
+        self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++17".format(
             cutlass=cutlass_path
         )
         self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"

python/tvm/relax/transform/transform.py

Lines changed: 13 additions & 7 deletions
@@ -300,7 +300,9 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
     return _ffi_api.FuseOps(fuse_opt_level)  # type: ignore


-def FuseOpsByPattern(patterns: List[Tuple]) -> tvm.ir.transform.Pass:
+def FuseOpsByPattern(
+    patterns: List[Tuple], annotate_codegen: bool = False
+) -> tvm.ir.transform.Pass:
     """Apply pattern matching to each function in the given module, and group matched expressions
     into a new function.

@@ -314,26 +316,30 @@ def FuseOpsByPattern(patterns: List[Tuple]) -> tvm.ir.transform.Pass:
         The string is the name of the corresponding pattern. It becomes the value of the kComposite
         attribute of a fused function after a successful matching.

+    annotate_codegen : bool
+        If True, wrap each created composite function with another function, whose body consists
+        only of a call to the composite function, and annotate the outer function with "Codegen"
+        and "global_symbol" attributes. The "Codegen" attribute is set as the prefix of the
+        corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu".
+
+        This must be True if the created composite functions are intended to be offloaded to
+        an external backend without using the MergeCompositeFunctions pass.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for pattern-based fusion.

     """
     pattern_names, df_patterns = zip(*patterns)
-    return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns)  # type: ignore
+    return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns, annotate_codegen)  # type: ignore


 def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
     """Group one or multiple composite functions created by FuseOpsByPattern into a new function.
     The new function will be annotated with "Codegen" and "global_symbol" attributes, and it
     is intented to be offloaded to an external backend.

-    Even if there is only one composite function, or a backend does not benefit from receiving
-    larger subgraphs, this pass is required to run for offloading (BYOC) since a composite function
-    needs to be wrapped by an outer function that are annotated with "Codegen" and "global_symbol"
-    attributes.
-
     Returns
     -------
     ret : tvm.transform.Pass
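With this change there are two routes to producing offloadable functions, per the updated docstrings; a brief contrast, assuming mod and patterns as in the earlier sketches:

from tvm import relax

# Route 1: one step; wrap and annotate during fusion.
mod1 = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)

# Route 2: two steps; fuse first, then group and annotate. Useful when a
# backend benefits from receiving larger, merged subgraphs.
mod2 = relax.transform.FuseOpsByPattern(patterns)(mod)
mod2 = relax.transform.MergeCompositeFunctions()(mod2)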
