
Commit fc29eea

minor fix
1 parent ece1dc3 commit fc29eea

File tree

4 files changed: +306 −56 lines


a.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from tilelang import tvm
import torch

# scratch check for torch dtype handling
vt = tvm.runtime.convert(torch.float32)

tvm.DataType('float32')

stubgen.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
import ast
import re
# from rich import print
# from argparse import ArgumentParser

with open('tilelang/language/tir/op.py') as f:
    data = f.read()

tree = ast.parse(data)

def convert_tree(x):
    """Recursively convert an AST node into plain dicts/lists for inspection."""
    result = {}
    for fname, value in ast.iter_fields(x):
        if isinstance(value, list):
            result[fname] = [convert_tree(v) if isinstance(v, ast.AST) else v for v in value]
        elif isinstance(value, ast.AST):
            result[fname] = convert_tree(value)
        else:
            result[fname] = value
    return result

# print(convert_tree(tree))

funcs = {}

# docstring type names that should be rewritten in the stubs
subst = {
    'Expr': 'PrimExpr',
    'UIntImm': 'IntImm',
    'tvm.Expr': 'PrimExpr'
}

for fdef in tree.body:
    if not isinstance(fdef, ast.FunctionDef):
        continue
    if not isinstance(fdef.body[0], ast.Expr):
        continue
    value = fdef.body[0].value
    if not isinstance(value, ast.Constant):
        continue
    data = value.value
    if not isinstance(data, str):
        continue
    # harvest parameter/return types from the numpydoc-style docstring
    lines = data.splitlines()
    ty = None
    annots = {}
    for i, line in enumerate(lines):
        if i > 0 and re.fullmatch(r' \s*----+', line):
            annot = lines[i - 1]
            ty = None
            if annot == '    Parameters':
                ty = 'param'
            if annot == '    Returns':
                ty = 'return'
        if mat := re.fullmatch(r'\s+([A-Za-z_][A-Za-z0-9_]*)\s*:\s+(.*)', line):
            name, val = mat.groups()
            val = subst.get(val, val)
            if ty == 'param':
                annots[name] = val
            if ty == 'return':
                annots['return'] = val

    # attach the harvested annotations to the function's arguments
    pe_arg = []
    span_arg = []
    other_arg = []
    for args in fdef.args.args:
        if args.arg in annots:
            annot = annots[args.arg]
            if annot == 'PrimExpr':
                pe_arg.append(args.arg)
            elif annot == 'Optional[Span]':
                span_arg.append(args.arg)
            else:
                other_arg.append(args.arg)
            try:
                args.annotation = ast.parse(annot).body[0].value
            except Exception as e:
                print(annot, repr(e))
        else:
            other_arg.append(args.arg)
    if 'return' in annots:
        try:
            fdef.returns = ast.parse(annots['return']).body[0].value
        except Exception as e:
            print(annots['return'], repr(e))
    # ops that take and return nothing but PrimExpr collapse to a TypeVar signature
    if annots.get('return', None) == 'PrimExpr' and not other_arg:
        print('UT Prim: ', fdef.name)
        Tvar = ast.parse('_T').body[0].value
        for args in fdef.args.args:
            if args.arg in pe_arg:
                args.annotation = Tvar
        fdef.returns = Tvar
    # replace the body with an ellipsis stub
    fdef.body = ast.parse('...').body
    # funcs.append(fdef)
    funcs[fdef.name] = fdef

# tree.body = funcs
# print(ast.unparse(tree))

with open('tilelang/language/tir/ir.py') as f:
    data = f.read()

all_funcs = []

# keep only the ops that ir.py actually re-exports via _op_wrapper
for name in re.findall(r'([A-Za-z_][A-Za-z0-9_]*) = _op_wrapper', data):
    if name in funcs:
        print(name)
        all_funcs.append(funcs[name])

tree.body = all_funcs

with open('tilelang/language/tir/ir.pyi', 'w') as f:
    f.write(ast.unparse(tree))
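
For reference, the stubs this emits should look roughly like the sketch below: an op whose docstring declares only PrimExpr parameters (plus an optional Span) and a PrimExpr return collapses to a _T TypeVar signature, while everything else keeps its docstring-derived annotations. The entries here are hypothetical examples, and the generated ir.pyi would still need a preamble defining _T, PrimExpr, Span, and Optional, since the script writes out only the function stubs:

def exp(x: _T) -> _T:
    ...

def call_extern(dtype: str, func_name: str, *args, span: Optional[Span]=None) -> PrimExpr:
    ...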

testing/python/language/test_tilelang_language_frontend_v2.py

Lines changed: 57 additions & 56 deletions
@@ -145,62 +145,63 @@ def test_str_repr():
     buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841
 
 
-def test_torch_eq():
-    dtypes = [
-        T.bool,
-        T.short,
-        T.int,
-        T.long,
-        T.half,
-        T.float,
-        T.long,
-        T.int8,
-        T.int16,
-        T.int32,
-        T.int64,
-        T.uint8,
-        T.uint16,
-        T.uint32,
-        T.uint64,
-        T.float8_e4m3fn,
-        T.float8_e4m3fnuz,
-        T.float8_e5m2,
-        T.float8_e5m2fnuz,
-        T.float8_e8m0fnu,
-        T.float16,
-        T.bfloat16,
-        T.float32,
-        T.float64,
-    ]
-    torch_dtypes = [
-        torch.bool,
-        torch.short,
-        torch.int,
-        torch.long,
-        torch.half,
-        torch.float,
-        torch.long,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.uint8,
-        torch.uint16,
-        torch.uint32,
-        torch.uint64,
-        torch.float8_e4m3fn,
-        torch.float8_e4m3fnuz,
-        torch.float8_e5m2,
-        torch.float8_e5m2fnuz,
-        torch.float8_e8m0fnu,
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-        torch.float64,
-    ]
-    for a, b in zip(dtypes, torch_dtypes):
-        assert a == b, f"{a} and {b} are not equal"
-        assert T.dtype(b) == a, "dtype conversion error"
+# not supported now
+# def test_torch_eq():
+#     dtypes = [
+#         T.bool,
+#         T.short,
+#         T.int,
+#         T.long,
+#         T.half,
+#         T.float,
+#         T.long,
+#         T.int8,
+#         T.int16,
+#         T.int32,
+#         T.int64,
+#         T.uint8,
+#         T.uint16,
+#         T.uint32,
+#         T.uint64,
+#         T.float8_e4m3fn,
+#         T.float8_e4m3fnuz,
+#         T.float8_e5m2,
+#         T.float8_e5m2fnuz,
+#         T.float8_e8m0fnu,
+#         T.float16,
+#         T.bfloat16,
+#         T.float32,
+#         T.float64,
+#     ]
+#     torch_dtypes = [
+#         torch.bool,
+#         torch.short,
+#         torch.int,
+#         torch.long,
+#         torch.half,
+#         torch.float,
+#         torch.long,
+#         torch.int8,
+#         torch.int16,
+#         torch.int32,
+#         torch.int64,
+#         torch.uint8,
+#         torch.uint16,
+#         torch.uint32,
+#         torch.uint64,
+#         torch.float8_e4m3fn,
+#         torch.float8_e4m3fnuz,
+#         torch.float8_e5m2,
+#         torch.float8_e5m2fnuz,
+#         torch.float8_e8m0fnu,
+#         torch.float16,
+#         torch.bfloat16,
+#         torch.float32,
+#         torch.float64,
+#     ]
+#     for a, b in zip(dtypes, torch_dtypes):
+#         assert a == b, f"{a} and {b} are not equal"
+#         assert T.dtype(b) == a, "dtype conversion error"
 
 
 def test_var_assign():

triteo_linear.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import tilelang
import torch
import tilelang.language as T

# n = 2 ** 25
B = 8
t = 2**11
D = 128
k = torch.randn(B, t, D, dtype=torch.float32, device='cuda')
s = torch.softmax(torch.randn(B, t, 3, dtype=torch.float32, device='cuda'), dim=-1)

def shift_with_zeros(x, shift, dim):
    """
    Shift a tensor along the given dimension, zero-filling the vacated positions.
    x: input tensor
    shift: positive shifts toward higher indices, negative toward lower indices
    dim: the dimension to shift along
    """
    if shift == 0:
        return x
    # build the zero filler: same shape as x, but |shift| wide along dim
    zeros_shape = list(x.shape)
    zeros_shape[dim] = abs(shift)
    zeros = torch.zeros(zeros_shape, dtype=x.dtype, device=x.device)

    if shift > 0:
        # shift toward higher indices
        return torch.cat([zeros, x.narrow(dim, 0, x.shape[dim] - shift)], dim=dim)
    else:
        # shift toward lower indices
        shift = -shift
        return torch.cat([x.narrow(dim, shift, x.shape[dim] - shift), zeros], dim=dim)
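
# e.g. shift_with_zeros(torch.arange(5), 2, dim=0)  -> tensor([0, 0, 0, 1, 2])
#      shift_with_zeros(torch.arange(5), -2, dim=0) -> tensor([2, 3, 4, 0, 0])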

def make_first_recurrent(k, s):
    """
    k: [b, h, t, d]
    s: [b, h, t, 3]
    Non-circular shift version: torch.roll replaced by shift_with_zeros.
    """
    b, h, t, d = k.shape
    device = k.device
    dtype = k.dtype
    # initialize S (no time dimension)
    S = torch.zeros((b, h, d, d), dtype=dtype, device=device)
    o = []
    for i in range(t):
        # save S[:, :, 0] for the current time step (adding a time axis)
        o.append(S[:, :, 0].unsqueeze(2))
        # zero-filled shifts along the row axis
        S_left = shift_with_zeros(S, 1, dim=2)    # j-1
        S_right = shift_with_zeros(S, -1, dim=2)  # j+1
        # fetch the three weights and broadcast
        w0 = s[:, :, i, 0].unsqueeze(-1).unsqueeze(-1)  # [b,h,1,1]
        w1 = s[:, :, i, 1].unsqueeze(-1).unsqueeze(-1)
        w2 = s[:, :, i, 2].unsqueeze(-1).unsqueeze(-1)
        # tridiagonal update of S
        S = S_left * w0 + S * w1 + S_right * w2
        # update row 0 of S
        S[:, :, 0] = S[:, :, 0] + w0.squeeze(-1) * k[:, :, i]
    return torch.cat(o, dim=2)

block_size = 32
num_block = t // block_size
# reference output, chunked to match the kernel's per-block recurrence: [B, t, D]
o_torch = torch.cat([
    make_first_recurrent(k[:, i * block_size:(i + 1) * block_size].unsqueeze(1),
                         s[:, i * block_size:(i + 1) * block_size].unsqueeze(1))
    for i in range(num_block)
], dim=2).squeeze(1)

@tilelang.jit
def inner_chunk_recurrent_fwd_init0(b, t, d, blk_t=block_size) -> tilelang.JITKernel:

    @T.prim_func
    def inner_chunk_recurrent_fwd_init0_(
        S: T.Tensor((b, t // blk_t, d, d), 'float32'),
        k: T.Tensor((b, t, d), 'float32'),
        s: T.Tensor((b, t, 3), 'float32'),
        o: T.Tensor((b, t, d), 'float32'),
    ):

        # one program per (batch, column) pair and per time block
        with T.Kernel(b * d, T.ceildiv(t, blk_t)) as (i_bd, i_t):
            i_b = i_bd // d
            i_d = i_bd % d
            S_temp = T.alloc_fragment(d, 'float32')
            S_down = T.alloc_fragment(d, 'float32')
            S_up = T.alloc_fragment(d, 'float32')
            S_mid = T.alloc_fragment(d, 'float32')
            for i0_d in T.Parallel(d):
                S_temp[i0_d] = 0
                S_down[i0_d] = 0
                S_up[i0_d] = 0
                S_mid[i0_d] = 0
            for i0_t in T.serial(blk_t):
                t_local = i_t * blk_t + i0_t
                # first write the top of the stack (row 0) into the output o
                o[i_b, t_local, i_d] = S_temp[0]
                # then the tridiagonal step: a weighted sum of adjacent rows
                down = s[i_b, t_local, 0]
                mid = s[i_b, t_local, 1]
                up = s[i_b, t_local, 2]
                for i0_d in T.Parallel(d - 1):
                    S_down[i0_d + 1] = S_temp[i0_d] * down
                for i0_d in T.Parallel(d - 1):
                    S_up[i0_d] = S_temp[i0_d + 1] * up
                for i0_d in T.Parallel(d):
                    S_mid[i0_d] = S_temp[i0_d] * mid
                S_down[0] = 0
                S_up[d - 1] = 0
                # replace (not accumulate): S_new[j] = down*S[j-1] + mid*S[j] + up*S[j+1]
                for i0_d in T.Parallel(d):
                    S_temp[i0_d] = S_mid[i0_d] + S_down[i0_d] + S_up[i0_d]
                # push the current k onto the top of the stack
                S_temp[0] += down * k[i_b, t_local, i_d]
            # store the block's final state S for later use
            for i0_d in T.Parallel(d):
                S[i_b, i_t, i0_d, i_d] = S_temp[i0_d]
    return inner_chunk_recurrent_fwd_init0_

# this parameter can be tuned freely
for blk_t in [32, 64, 128]:
    print(f'---------------- {blk_t=} ----------------')
    kernel = inner_chunk_recurrent_fwd_init0(B, t, D, blk_t)

    S = torch.empty(B, t // blk_t, D, D).to(k)
    o_tilelang = torch.empty_like(k)
    kernel(S, k, s, o_tilelang)
    if blk_t == 32:
        # the reference was chunked with block_size = 32, so only this case is comparable
        assert torch.all(o_torch == o_tilelang)
    with torch.profiler.profile() as prof:
        for _ in range(10):
            inner_chunk_recurrent_fwd_init0(B, t, D, blk_t)(S, k, s, o_tilelang)
    print(prof.key_averages().table())
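
The per-step update in both implementations is a tridiagonal linear map over the row index of S, which is what the kernel's S_down/S_mid/S_up fragments reproduce. A minimal standalone sketch of that equivalence, with scalar weights chosen purely for illustration (the script itself uses per-timestep weights from s):

import torch

b, h, d = 2, 1, 4
S = torch.randn(b, h, d, d)
w0, w1, w2 = 0.2, 0.5, 0.3

# explicit tridiagonal operator: w0 on the subdiagonal, w1 on the diagonal, w2 on the superdiagonal
tri = (w0 * torch.diag(torch.ones(d - 1), -1)
       + w1 * torch.eye(d)
       + w2 * torch.diag(torch.ones(d - 1), 1))

# zero-filled row shifts, as computed by shift_with_zeros(S, +1/-1, dim=2)
S_left = torch.cat([torch.zeros(b, h, 1, d), S[:, :, :-1]], dim=2)
S_right = torch.cat([S[:, :, 1:], torch.zeros(b, h, 1, d)], dim=2)

# the shift-and-weight update equals applying the tridiagonal matrix to S's rows
assert torch.allclose(S_left * w0 + S * w1 + S_right * w2, tri @ S)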
