Skip to content

Commit b5f36ad

Browse files
committed
tilelang frontend v2
1 parent 50e789d commit b5f36ad

File tree

7 files changed

+1182
-5
lines changed

7 files changed

+1182
-5
lines changed

examples/gdn/example_chunk_o_bwd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ def kernel(
256256
# for i_kv in T.Parallel(block_DK * block_DV):
257257
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
258258
for i_kv in T.Parallel(block_DK * block_DV):
259-
i_k, i_v = i_kv // block_DV, i_kv % block_DV
260-
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
259+
i_k, i_v_1 = i_kv // block_DV, i_kv % block_DV
260+
dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1]
261261
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
262262
dg_last_local[0] += dg_last_fragment_scalar[0]
263263

tilelang/language/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
# upstream tir script is fully compatible
99
from tvm.script.parser.tir import *
1010
from . import overrides as _overrides # noqa: F401
11-
from .tir import (
12-
prim_func, # noqa: F401
13-
)
11+
12+
# from .tir import prim_func, macro, # noqa: F401
13+
from .v2 import prim_func, macro # noqa: F401
1414
from .tir.ir import * # noqa: F401
1515
from tilelang.layout import Layout, Fragment # noqa: F401
1616
from .proxy import (

tilelang/language/v2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .builder import prim_func, macro # noqa: F401

0 commit comments

Comments
 (0)