1414# KIND, either express or implied. See the License for the
1515# specific language governing permissions and limitations
1616# under the License.
17- # pylint: disable=invalid-name, dangerous-default-value
17+ # pylint: disable=invalid-name, dangerous-default-value, arguments-differ
1818"""Driver for partitioning and building a Relay module for CUTLASS offload."""
1919import logging
2020import os
2121import multiprocessing
2222import tvm
23- from tvm import runtime , relay
23+ from tvm import runtime , relay , relax
2424from tvm .contrib .nvcc import get_cuda_version
2525from tvm ._ffi .registry import register_func
2626from .gen_gemm import CutlassGemmProfiler
@@ -516,6 +516,167 @@ def tune_cutlass_function(
516516 )
517517
518518
519+ def _extract_relax_function_info (f ):
520+ signature = {}
521+
522+ for i , arg in enumerate (f .params ):
523+ sinfo = arg .struct_info
524+ signature ["arg%d_shape" % i ] = list (sinfo .shape )
525+ signature ["arg%d_dtype" % i ] = sinfo .dtype
526+
527+ ret_sinfo = f .ret_struct_info
528+ signature ["ret_shape" ] = list (ret_sinfo .shape )
529+ signature ["ret_dtype" ] = ret_sinfo .dtype
530+
531+ op_attrs = {}
532+
533+ def fvisit (e ):
534+ nonlocal op_attrs
535+ if isinstance (e , relax .Call ) and str (e .op ) in ["relax.nn.conv2d" ]:
536+ op_attrs = e .attrs
537+
538+ relax .analysis .post_order_visit (f .body , fvisit )
539+
540+ return signature , op_attrs
541+
542+
543+ @relax .expr_functor .mutator
544+ class CutlassRelaxFunctionAnnotator (relax .PyExprMutator ):
545+ """A Relax function mutator that tunes and annotates CUTLASS composite functions
546+ with shape, dtype and generated templates.
547+ """
548+
549+ def __init__ (self , mod , conv2d_profiler , options ):
550+ super ().__init__ (mod )
551+ self .options = options
552+ self .conv2d_profiler = conv2d_profiler
553+
554+ def handle_conv2d (self , f , op_type ):
555+ """Tune and annotate a conv2d op."""
556+ signature , op_attrs = _extract_relax_function_info (f )
557+
558+ d_shape = signature ["arg0_shape" ]
559+ w_shape = signature ["arg1_shape" ]
560+ out_shape = signature ["ret_shape" ]
561+ data_dtype = signature ["arg0_dtype" ]
562+ weight_dtype = signature ["arg1_dtype" ]
563+ out_dtype = signature ["ret_dtype" ]
564+ padding = op_attrs ["padding" ]
565+ strides = op_attrs ["strides" ]
566+ dilation = op_attrs ["dilation" ]
567+ conv_kind = ConvKind .Fprop
568+
569+ use_3xtf32 = self .options .get ("use_3xtf32" , False )
570+ profile_all_alignments = self .options .get ("profile_all_alignments" , False )
571+ find_first_valid = self .options .get ("find_first_valid" , True )
572+ use_multiprocessing = self .options .get ("use_multiprocessing" , True )
573+ split_k_slices = self .options .get ("split_k_slices" , [1 ])
574+
575+ op_name , op_def , _ = self .conv2d_profiler .profile (
576+ op_type ,
577+ d_shape ,
578+ w_shape ,
579+ padding ,
580+ strides ,
581+ dilation ,
582+ out_dtype ,
583+ data_dtype ,
584+ weight_dtype ,
585+ use_3xtf32 ,
586+ conv_kind ,
587+ split_k_slices ,
588+ profile_all_alignments ,
589+ find_first_valid = find_first_valid ,
590+ use_multiprocessing = use_multiprocessing ,
591+ )
592+
593+ return f .with_attrs (
594+ {
595+ "op_type" : op_type ,
596+ "arg0_dtype" : data_dtype ,
597+ "arg1_dtype" : weight_dtype ,
598+ "ret_dtype" : out_dtype ,
599+ "arg0_shape" : d_shape ,
600+ "arg1_shape" : w_shape ,
601+ "ret_shape" : out_shape ,
602+ "strides" : strides ,
603+ "padding" : padding ,
604+ "dilation" : dilation ,
605+ "cutlass_op_name" : op_name ,
606+ "cutlass_op_def" : op_def ,
607+ }
608+ )
609+
610+ def visit_function_ (self , f ):
611+ if "Composite" not in f .attrs :
612+ body = super ().visit_expr (f .body )
613+ return relax .Function (f .params , body , f .ret_struct_info , f .attrs , f .span )
614+
615+ op_type = f .attrs ["Composite" ]
616+
617+ if "conv2d" in op_type :
618+ return self .handle_conv2d (f , op_type )
619+
620+ raise ValueError ("Unsupported composite {}" .format (op_type ))
621+
622+ def visit_span (self , span ):
623+ return span
624+
625+
626+ @register_func ("contrib.cutlass.tune_relax_function" )
627+ def profile_relax_function (functions , options ):
628+ """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates."""
629+ tmp_dir = options .get ("tmp_dir" , "./tmp" )
630+ sm = options .get ("sm" , 80 )
631+ conv2d_profiler = CutlassConv2DProfiler (sm , _get_cutlass_path (), tmp_dir )
632+
633+ annotated_functions = []
634+
635+ for f in functions :
636+ annotator = CutlassRelaxFunctionAnnotator (
637+ tvm .IRModule .from_expr (f ), conv2d_profiler , options
638+ )
639+ annotated_functions .append (annotator .visit_expr (f ))
640+
641+ return annotated_functions
642+
643+
644+ @register_func ("contrib.cutlass.compile" )
645+ def compile_cutlass_module (c_source_module , options ):
646+ """Compile all CUTLASS kernels in the given C-source module.
647+
648+ Parameters
649+ ----------
650+ c_source_module: runtime.Module
651+ A C-source module containing CUTLASS kernels.
652+
653+ options: dict
654+ Compilation options. Currently recognizes
655+ "sm": The target architecture (compute capability), for example 75 or 80 (default: 80)
656+ "threads": The number of threads to use in NVCC parallel compilation (default:
657+ use all logical cores)
658+ "use_fast_math": Whether or not to use faster but approximate arithmetic in some
659+ CUTLASS epilogues (default: False)
660+
661+ Returns
662+ -------
663+ rt_mod : runtime.Module
664+ A runtime module where all cutlass kernels have been compiled.
665+ """
666+ tmp_dir = options .get ("tmp_dir" , "./tmp" )
667+ defaults = {"sm" : 80 , "threads" : - 1 , "use_fast_math" : False }
668+ compile_config = {key : options .get (key , val ) for key , val in defaults .items ()}
669+
670+ function_names = c_source_module .get_function ("get_func_names" )()
671+ compile_options = _get_cutlass_compile_options (** compile_config )
672+ lib_path = os .path .join (tmp_dir , "cutlass.o" )
673+ logger .info ("Compiling generated CUTLASS code" )
674+ c_source_module .export_library (lib_path , workspace_dir = tmp_dir , ** compile_options )
675+
676+ # Recover static library
677+ return tvm .runtime .load_static_library (lib_path , function_names )
678+
679+
519680@register_func ("relay.ext.cutlass.compile_for_cutlass" )
520681def compile_for_cutlass (mod , cutlass_target ):
521682 """Given an IRModule with at least one Compiler='cutlass' Relay function, return a
@@ -549,6 +710,7 @@ def compile_for_cutlass(mod, cutlass_target):
549710 key : cutlass_target .attrs .get (key ) for key in ["sm" , "threads" , "use_fast_math" ]
550711 }
551712 tmp_dir = cutlass_target .attrs .get ("tmp_dir" )
713+ compile_config ["tmp_dir" ] = tmp_dir
552714
553715 # Tune
554716 logger .info ("Tuning for CUTLASS" )
@@ -558,18 +720,7 @@ def compile_for_cutlass(mod, cutlass_target):
558720 logger .info ("Creating CSource module for CUTLASS" )
559721 create_c_source_module = tvm ._ffi .get_global_func ("relay.ext.cutlass.create_c_source_module" )
560722 c_module = create_c_source_module (mod )
561- function_names = c_module .get_function ("get_func_names" )()
562- compile_options = _get_cutlass_compile_options (** compile_config )
563- lib_path = os .path .join (tmp_dir , "cutlass.o" )
564- logger .info ("Compiling generated CUTLASS code" )
565- c_module .export_library (lib_path , workspace_dir = tmp_dir , ** compile_options )
566-
567- # Recover static library
568- logger .info ("Loading compiled CUTLASS code" )
569- final_mod = tvm .runtime .load_static_library (lib_path , function_names )
570-
571- logger .info ("Done with CUTLASS compilation" )
572- return final_mod
723+ return compile_cutlass_module (c_module , compile_config )
573724
574725
575726def finalize_modules (lib , lib_path = "compile.so" , tmp_dir = "./tmp" ):
@@ -633,3 +784,29 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro",
633784 fo .write (code )
634785 lib = tvm .runtime .load_module (lib_path )
635786 return tvm .runtime .vm .Executable .load_exec (code , lib )
787+
788+
789+ def finalize_modules_relax (vm_exec , lib_path = "compile.so" , tmp_dir = "./tmp" ):
790+ """finalize_modules_vm equivalent for Relax VM.
791+
792+ Parameters
793+ ----------
794+ vm_exec : vm.Executable
795+ The output from relax.vm.build containing compiled host code and kernels.
796+
797+ lib_path : string
798+ The path to a shared library which will be generated as the result of the build process.
799+
800+ tmp_dir : string
801+ A temporary directory where intermediate compiled artifacts will be stored.
802+
803+ Returns
804+ -------
805+ updated_vm_exec : relax.vm.Executable
806+ The updated VM executable with all compilation and linking completed.
807+ """
808+ lib_path = os .path .join (tmp_dir , lib_path )
809+ vm_exec .mod .export_library (lib_path , workspace_dir = tmp_dir , cc = "nvcc" )
810+ lib = tvm .runtime .load_module (lib_path )
811+
812+ return relax .vm .Executable (lib )
0 commit comments