1616# under the License.
1717
1818if (USE_CUDA AND USE_CUTLASS)
19- tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
19+ set (CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA} >,$<BOOL:${USE_CUTLASS} >>" )
20+ set (CUTLASS_RUNTIME_OBJS "" )
21+
22+ tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
23+ src/relay/backend/contrib/cutlass/*.cc
24+ src/relax/backend/contrib/cutlass/*.cc
25+ )
2026 list (APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC} )
2127
2228 set (FPA_INTB_GEMM_TVM_BINDING ON )
2329 set (FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR} )
2430
25- set (CUTLASS_DIR ${PROJECT_SOURCE_DIR} /3rdparty/ cutlass)
31+ ### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule
2632 add_subdirectory (${PROJECT_SOURCE_DIR} /3rdparty/cutlass_fpA_intB_gemm)
33+ target_include_directories (fpA_intB_gemm PRIVATE
34+ ${PROJECT_SOURCE_DIR} /3rdparty/cutlass_fpA_intB_gemm
35+ ${PROJECT_SOURCE_DIR} /3rdparty/cutlass_fpA_intB_gemm/cutlass/include
36+ )
37+ set (CUTLASS_FPA_INTB_RUNTIME_SRCS "" )
38+ list (APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
39+ list (APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
40+ add_library (fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS} )
41+ target_include_directories (fpA_intB_cutlass_objs PRIVATE
42+ ${PROJECT_SOURCE_DIR} /3rdparty/cutlass_fpA_intB_gemm/cutlass/include
43+ )
44+ list (APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND} :$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>" )
45+
46+ ### Build cutlass runtime objects for flash attention using its cutlass submodule
2747 add_subdirectory (${PROJECT_SOURCE_DIR} /3rdparty/libflash_attn)
48+ target_include_directories (flash_attn PRIVATE
49+ ${PROJECT_SOURCE_DIR} /3rdparty/libflash_attn
50+ ${PROJECT_SOURCE_DIR} /3rdparty/libflash_attn/cutlass/include
51+ )
52+ set (CUTLASS_FLASH_ATTN_RUNTIME_SRCS "" )
53+ list (APPEND CUTLASS_FLASH_ATTN_RUNTIME_SRCS src/runtime/contrib/cutlass/flash_decoding.cu)
54+ add_library (flash_attn_cutlass_objs OBJECT ${CUTLASS_FLASH_ATTN_RUNTIME_SRCS} )
55+ target_include_directories (flash_attn_cutlass_objs PRIVATE
56+ ${PROJECT_SOURCE_DIR} /3rdparty/libflash_attn/cutlass/include
57+ )
58+ list (APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND} :$<TARGET_OBJECTS:flash_attn_cutlass_objs>>" )
59+
60+ ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
61+ set (CUTLASS_DIR ${PROJECT_SOURCE_DIR} /3rdparty/cutlass)
62+ set (TVM_CUTLASS_RUNTIME_SRCS "" )
63+ if (CMAKE_CUDA_ARCHITECTURES MATCHES "90" )
64+ list (APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_fp8_gemm.cu)
65+ endif ()
66+ if (TVM_CUTLASS_RUNTIME_SRCS)
67+ add_library (tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS} )
68+ target_include_directories (tvm_cutlass_objs PRIVATE ${CUTLASS_DIR} /include )
69+ list (APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND} :$<TARGET_OBJECTS:tvm_cutlass_objs>>" )
70+ endif ()
2871
29- include_directories (3rdparty/cutlass_fpA_intB_gemm
30- 3rdparty/cutlass_fpA_intB_gemm/cutlass/include ) # FIXME
31- list (APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
32- list (APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
33- list (APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/flash_decoding.cu)
72+ ### Add cutlass objects to list of TVM runtime extension objs
73+ list (APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS} " )
3474
3575 message (STATUS "Build with CUTLASS" )
36- endif ()
76+ endif ()
0 commit comments