forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonnxruntime_compile_triton_kernel.cmake
35 lines (30 loc) · 1.42 KB
/
onnxruntime_compile_triton_kernel.cmake
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# The Triton kernel compile step is driven by a Python script, so a Python 3
# interpreter is mandatory (provides the Python3::Interpreter imported target).
find_package(Python3 REQUIRED COMPONENTS Interpreter)

# Triton kernel scripts to compile, given as paths relative to the repository
# root. The list is only defined for ROCm builds; otherwise the variable is
# left untouched.
if(onnxruntime_USE_ROCM)
  set(triton_kernel_scripts
    "onnxruntime/core/providers/rocm/math/softmax_triton.py"
    "onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
  )
endif()
# compile_triton_kernel(<out_triton_kernel_obj_file> <out_triton_kernel_header_dir>)
#
# Registers the build rule that runs ${REPO_ROOT}/tools/ci_build/compile_triton.py
# over every script listed in `triton_kernel_scripts` and exposes the results
# to the caller.
#
# Reads from the calling scope:
#   REPO_ROOT              - repository root; prefixed to each script path and
#                            used to locate the compiler script.
#   triton_kernel_scripts  - list of kernel scripts, relative to REPO_ROOT.
#                            If the list is empty, `--script_files` is passed
#                            without any values.
#
# Output arguments (names of variables set in the caller's scope):
#   out_triton_kernel_obj_file   - receives the path of the generated archive
#                                  (triton_kernel_infos.a).
#   out_triton_kernel_header_dir - receives the directory that contains the
#                                  generated header (triton_kernel_infos.h).
#
# Side effect: creates the custom target `onnxruntime_triton_kernel`, which
# depends on both generated files.
function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir)
  # compile triton kernel, generate .a and .h files
  set(triton_kernel_compiler "${REPO_ROOT}/tools/ci_build/compile_triton.py")
  set(out_dir "${CMAKE_CURRENT_BINARY_DIR}/triton_kernels")
  set(out_obj_file "${out_dir}/triton_kernel_infos.a")
  set(header_file "${out_dir}/triton_kernel_infos.h")

  # Turn the repo-relative script paths into absolute ones. A function has its
  # own variable scope, so only this local copy of the list is rewritten.
  list(TRANSFORM triton_kernel_scripts PREPEND "${REPO_ROOT}/")

  # Single-path expansions are quoted so paths containing spaces or semicolons
  # stay one argument; the script list is intentionally left unquoted so each
  # script becomes its own argument. VERBATIM makes the escaping of the command
  # line identical on every platform/generator.
  add_custom_command(
    OUTPUT "${out_obj_file}" "${header_file}"
    COMMAND Python3::Interpreter "${triton_kernel_compiler}"
            --header "${header_file}"
            --script_files ${triton_kernel_scripts}
            --obj_file "${out_obj_file}"
    DEPENDS ${triton_kernel_scripts} "${triton_kernel_compiler}"
    COMMENT "Triton compile generates: ${out_obj_file}"
    VERBATIM
  )

  # Named target so other targets can order themselves after kernel generation.
  add_custom_target(onnxruntime_triton_kernel DEPENDS "${out_obj_file}" "${header_file}")

  # Return values to the caller.
  set(${out_triton_kernel_obj_file} "${out_obj_file}" PARENT_SCOPE)
  set(${out_triton_kernel_header_dir} "${out_dir}" PARENT_SCOPE)
endfunction()