Skip to content

Commit 7e83cc1

Browse files
authored
[BACKEND] Fix the bug in editable mode and the issue where dynamic libraries cannot be automatically uninstalled. (#123)
* modified: python/setup.py, modified: python/setup_tools/setup_helper.py, modified: third_party/iluvatar/CMakeLists.txt
* modified: CMakeLists.txt, modified: python/setup.py, modified: python/setup_tools/setup_helper.py, modified: third_party/mthreads/CMakeLists.txt
* modified: .github/workflows/iluvatar-build-and-test.yml
* modified: .github/workflows/iluvatar-build-and-test.yml, modified: .github/workflows/mthreads-build-and-test.yml, new file: test/bin/iluvatar/add_kernel.ttgir, new file: test/bin/iluvatar/add_kernel.ttir, new file: test/bin/mthreads/add_kernel.ttgir, new file: test/bin/mthreads/add_kernel.ttir
* modified: python/setup.py, modified: python/setup_tools/setup_helper.py
* modified: .github/workflows/iluvatar-build-and-test.yml, modified: .github/workflows/mthreads-build-and-test.yml, modified: python/setup_tools/setup_helper.py
* modified: python/setup.py
* modified: .github/workflows/iluvatar-build-and-test.yml
1 parent a52bddf commit 7e83cc1

File tree

11 files changed

+241
-28
lines changed

11 files changed

+241
-28
lines changed

.github/workflows/iluvatar-build-and-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ jobs:
5959
shell: bash
6060
run: |
6161
CUDA_VISIBLE_DEVICES=15 python3 -m pytest -s third_party/iluvatar/python/test/unit
62+
./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt --pass-pipeline='builtin.module(convert-triton-to-tritongpu{target="cuda:CC" num-warps=4 threads-per-warp=32 num-ctas=1})' ./test/bin/iluvatar/add_kernel.ttir
63+
./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt --convert-arith-to-llvm ./test/bin/iluvatar/add_kernel.ttgir
6264
cd python/tutorials
6365
python3 01-vector-add.py
6466
python3 04-low-memory-dropout.py

.github/workflows/mthreads-build-and-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,6 @@ jobs:
6363
- name: FlagTree Test on Mthreads
6464
shell: bash
6565
run: |
66+
./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt --pass-pipeline='builtin.module(convert-triton-to-tritongpu{target="cuda:CC" num-warps=4 threads-per-warp=32 num-ctas=1})' ./test/bin/mthreads/add_kernel.ttir
67+
./python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt --convert-ub-to-llvm ./test/bin/mthreads/add_kernel.ttgir
6668
python3 -m pytest -s third_party/mthreads/python/test/unit

CMakeLists.txt

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,23 @@ elseif(FLAGTREE_BACKEND STREQUAL "iluvatar")
2222
remove_definitions(-D_GLIBCXX_USE_CXX11_ABI=1)
2323
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
2424
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
25-
add_definitions(-DDEFAULT_PLUGIN_DIR="${Python3_SITELIB}/triton/_C")
25+
if(EDITABLE_MODE)
26+
set (DEFAULT_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/iluvatar")
27+
else()
28+
set (DEFAULT_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
29+
endif()
30+
add_definitions(-DDEFAULT_PLUGIN_DIR="${DEFAULT_PLUGIN_DIR}")
2631
elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
2732
set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}")
2833
set(CMAKE_C_COMPILER clang)
2934
set(CMAKE_CXX_COMPILER clang++)
3035
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
31-
add_definitions(-DDEFAULT_PLUGIN_DIR="${Python3_SITELIB}/triton/_C")
32-
# set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
36+
if(EDITABLE_MODE)
37+
set (DEFAULT_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/mthreads")
38+
else()
39+
set (DEFAULT_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
40+
endif()
41+
add_definitions(-DDEFAULT_PLUGIN_DIR="${DEFAULT_PLUGIN_DIR}")
3342
elseif(FLAGTREE_BACKEND STREQUAL "hcu")
3443
add_definitions(-D__HCU__)
3544
endif()

python/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ def run(self):
555555

556556
package_data = {
557557
"triton/tools": package_data_tools,
558+
"": ["*TritonPlugin.so"],
558559
**{f"triton/backends/{b.name}": b.package_data
559560
for b in backends},
560561
}

python/setup_tools/setup_helper.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,17 @@ def install_extension(*args, **kargs):
3434

3535

3636
def get_backend_cmake_args(*args, **kargs):
37+
if "editable_wheel" in sys.argv:
38+
editable = True
39+
else:
40+
editable = False
41+
# lit is used by the test suite
42+
handle_plugin_backend(editable)
3743
try:
38-
return activated_module.get_backend_cmake_args(*args, **kargs)
44+
cmake_args = activated_module.get_backend_cmake_args(*args, **kargs)
45+
if "editable_wheel" in sys.argv:
46+
cmake_args += ["-DEDITABLE_MODE=ON"]
47+
return cmake_args
3948
except Exception:
4049
return []
4150

@@ -308,6 +317,17 @@ def handle_flagtree_backend():
308317
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
309318

310319

320+
def handle_plugin_backend(editable):
321+
if flagtree_backend in ["iluvatar", "mthreads"] and editable is False:
322+
src_plugin_path = str(
323+
os.getenv("HOME")) + "/.flagtree/" + flagtree_backend + "/" + flagtree_backend + "TritonPlugin.so"
324+
dst_plugin_dir = sysconfig.get_paths()['purelib'] + "/triton/_C"
325+
if not os.path.exists(dst_plugin_dir):
326+
os.makedirs(dst_plugin_dir)
327+
dst_plugin_path = dst_plugin_dir + "/" + flagtree_backend + "TritonPlugin.so"
328+
shutil.copy(src_plugin_path, dst_plugin_path)
329+
330+
311331
def set_env(env_dict: dict):
312332
for env_k, env_v in env_dict.items():
313333
os.environ[env_k] = str(env_v)
@@ -423,12 +443,3 @@ def check_env(env_val):
423443
pre_hock=lambda: check_env('LLVM_SYSPATH'),
424444
post_hock=set_llvm_env,
425445
)
426-
427-
if flagtree_backend in ["iluvatar", "mthreads"]:
428-
src_plugin_path = str(
429-
os.getenv("HOME")) + "/.flagtree/" + flagtree_backend + "/" + flagtree_backend + "TritonPlugin.so"
430-
dst_plugin_dir = sysconfig.get_paths()['purelib'] + "/triton/_C"
431-
if not os.path.exists(dst_plugin_dir):
432-
os.makedirs(dst_plugin_dir)
433-
dst_plugin_path = dst_plugin_dir + "/" + flagtree_backend + "TritonPlugin.so"
434-
shutil.copy(src_plugin_path, dst_plugin_path)

test/bin/iluvatar/add_kernel.ttgir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0], loadType = 0, smeWarpsPerCTA = [0]}>
2+
#loc = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0)
3+
module attributes {"triton_gpu.dot.num-stages" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:71", "triton_gpu.threads-per-warp" = 64 : i32} {
4+
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
5+
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
6+
%0 = tt.get_program_id x : i32 loc(#loc2)
7+
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
8+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> loc(#loc4)
9+
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> loc(#loc5)
10+
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> loc(#loc5)
11+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> loc(#loc6)
12+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> loc(#loc6)
13+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc7)
14+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc7)
15+
%9 = tt.load %8, %6 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> loc(#loc8)
16+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc9)
17+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc9)
18+
%12 = tt.load %11, %6 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> loc(#loc10)
19+
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
20+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc12)
21+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc12)
22+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
23+
tt.return loc(#loc14)
24+
} loc(#loc)
25+
} loc(#loc)
26+
#loc1 = loc(unknown)
27+
#loc2 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":37:24)
28+
#loc3 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":42:24)
29+
#loc4 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":43:41)
30+
#loc5 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":43:28)
31+
#loc6 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":45:21)
32+
#loc7 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":48:24)
33+
#loc8 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":48:16)
34+
#loc9 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":49:24)
35+
#loc10 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":49:16)
36+
#loc11 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":50:17)
37+
#loc12 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":52:26)
38+
#loc13 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":52:35)
39+
#loc14 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":52:4)

test/bin/iluvatar/add_kernel.ttir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#loc = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0)
2+
module {
3+
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
4+
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
5+
%0 = tt.get_program_id x : i32 loc(#loc2)
6+
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
7+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
8+
%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
9+
%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
10+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6)
11+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6)
12+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)
13+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)
14+
%9 = tt.load %8, %6 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> loc(#loc8)
15+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
16+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
17+
%12 = tt.load %11, %6 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> loc(#loc10)
18+
%13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11)
19+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
20+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
21+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13)
22+
tt.return loc(#loc14)
23+
} loc(#loc)
24+
} loc(#loc)
25+
#loc1 = loc(unknown)
26+
#loc2 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":37:24)
27+
#loc3 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":42:24)
28+
#loc4 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":43:41)
29+
#loc5 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":43:28)
30+
#loc6 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":45:21)
31+
#loc7 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":48:24)
32+
#loc8 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":48:16)
33+
#loc9 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":49:24)
34+
#loc10 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":49:16)
35+
#loc11 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":50:17)
36+
#loc12 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":52:26)
37+
#loc13 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":52:35)
38+
#loc14 = loc("/usr/local/corex-4.3.0.20250707/flagtree/python/tutorials/01-vector-add.py":52:4)

test/bin/mthreads/add_kernel.ttgir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
2+
#loc = loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0)
3+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "musa:31", "triton_gpu.threads-per-warp" = 32 : i32} {
4+
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
5+
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
6+
%0 = tt.get_program_id x : i32 loc(#loc2)
7+
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
8+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> loc(#loc4)
9+
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> loc(#loc5)
10+
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> loc(#loc5)
11+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> loc(#loc6)
12+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> loc(#loc6)
13+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc7)
14+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc7)
15+
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc8)
16+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc9)
17+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc9)
18+
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc10)
19+
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
20+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc12)
21+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc12)
22+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
23+
tt.return loc(#loc14)
24+
} loc(#loc)
25+
} loc(#loc)
26+
#loc1 = loc(unknown)
27+
#loc2 = loc("/root/flagtree/python/tutorials/01-vector-add.py":37:24)
28+
#loc3 = loc("/root/flagtree/python/tutorials/01-vector-add.py":42:24)
29+
#loc4 = loc("/root/flagtree/python/tutorials/01-vector-add.py":43:41)
30+
#loc5 = loc("/root/flagtree/python/tutorials/01-vector-add.py":43:28)
31+
#loc6 = loc("/root/flagtree/python/tutorials/01-vector-add.py":45:21)
32+
#loc7 = loc("/root/flagtree/python/tutorials/01-vector-add.py":48:24)
33+
#loc8 = loc("/root/flagtree/python/tutorials/01-vector-add.py":48:16)
34+
#loc9 = loc("/root/flagtree/python/tutorials/01-vector-add.py":49:24)
35+
#loc10 = loc("/root/flagtree/python/tutorials/01-vector-add.py":49:16)
36+
#loc11 = loc("/root/flagtree/python/tutorials/01-vector-add.py":50:17)
37+
#loc12 = loc("/root/flagtree/python/tutorials/01-vector-add.py":52:26)
38+
#loc13 = loc("/root/flagtree/python/tutorials/01-vector-add.py":52:35)
39+
#loc14 = loc("/root/flagtree/python/tutorials/01-vector-add.py":52:4)

test/bin/mthreads/add_kernel.ttir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#loc = loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0)
2+
module {
3+
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/flagtree/python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
4+
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
5+
%0 = tt.get_program_id x : i32 loc(#loc2)
6+
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
7+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
8+
%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
9+
%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
10+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6)
11+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6)
12+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)
13+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)
14+
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8)
15+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
16+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
17+
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10)
18+
%13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11)
19+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
20+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
21+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13)
22+
tt.return loc(#loc14)
23+
} loc(#loc)
24+
} loc(#loc)
25+
#loc1 = loc(unknown)
26+
#loc2 = loc("/root/flagtree/python/tutorials/01-vector-add.py":37:24)
27+
#loc3 = loc("/root/flagtree/python/tutorials/01-vector-add.py":42:24)
28+
#loc4 = loc("/root/flagtree/python/tutorials/01-vector-add.py":43:41)
29+
#loc5 = loc("/root/flagtree/python/tutorials/01-vector-add.py":43:28)
30+
#loc6 = loc("/root/flagtree/python/tutorials/01-vector-add.py":45:21)
31+
#loc7 = loc("/root/flagtree/python/tutorials/01-vector-add.py":48:24)
32+
#loc8 = loc("/root/flagtree/python/tutorials/01-vector-add.py":48:16)
33+
#loc9 = loc("/root/flagtree/python/tutorials/01-vector-add.py":49:24)
34+
#loc10 = loc("/root/flagtree/python/tutorials/01-vector-add.py":49:16)
35+
#loc11 = loc("/root/flagtree/python/tutorials/01-vector-add.py":50:17)
36+
#loc12 = loc("/root/flagtree/python/tutorials/01-vector-add.py":52:26)
37+
#loc13 = loc("/root/flagtree/python/tutorials/01-vector-add.py":52:35)
38+
#loc14 = loc("/root/flagtree/python/tutorials/01-vector-add.py":52:4)

third_party/iluvatar/CMakeLists.txt

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,42 @@
11
add_subdirectory(include)
22
add_subdirectory(lib)
33
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
4-
set (ILUVATAR_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
4+
option(EDITABLE_MODE "Build in developer (editable) mode" OFF)
5+
if(EDITABLE_MODE)
6+
set(ILUVATAR_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
7+
else()
8+
set (ILUVATAR_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
9+
endif()
510
if(TRITON_BUILD_PYTHON_MODULE)
611
if(FLAGTREE_PLUGIN)
712
add_subdirectory(plugin)
813
add_triton_plugin(TritonILUVATAR
914
SHARED_LIB iluvatarTritonPlugin
1015
)
1116
else()
12-
find_library(iluvatarTritonPluginLib
17+
if(EDITABLE_MODE)
18+
find_library(iluvatarTritonPluginLib
1319
NAMES
1420
iluvatarTritonPlugin.so
1521
PATHS
16-
${ILUVATAR_PLUGIN_DIR}
22+
${CMAKE_CURRENT_SOURCE_DIR}
1723
REQUIRED
18-
)
19-
add_triton_plugin(TritonILUVATAR
20-
SHARED_LIB ${iluvatarTritonPluginLib}
21-
)
24+
)
25+
add_triton_plugin(TritonILUVATAR
26+
SHARED_LIB ${iluvatarTritonPluginLib}
27+
)
28+
else()
29+
find_library(iluvatarTritonPluginLib
30+
NAMES
31+
iluvatarTritonPlugin.so
32+
PATHS
33+
${ILUVATAR_PLUGIN_DIR}
34+
REQUIRED
35+
)
36+
add_triton_plugin(TritonILUVATAR
37+
SHARED_LIB ${iluvatarTritonPluginLib}
38+
)
39+
endif()
2240
endif()
2341
endif()
2442

0 commit comments

Comments
 (0)