Skip to content

Commit dbaeccf

Browse files
authored
Match decode + NT-GeMV + [ewise] pattern (#2)
This PR uses FuseTIRByPattern to match the decode + NT-GeMV + optionally a trailing element-wise TIR function. E2E verified locally. The next step is to turn off NT-matmul and update the quantization encoding/decoding accordingly so that the quantization encoding func transposes the weights from T to N, and also update this pattern match function accordingly.
1 parent 8a57430 commit dbaeccf

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def mod_transform_before_build(
117117
mod = relax.transform.FuseTIR()(mod)
118118

119119
mod = web_llm.transform.GroupQuantize(group_size=32, sym=False)(mod)
120+
mod = web_llm.transform.FuseDecodeNTMatmulEwise()(mod)
120121
mod = relax.transform.DeadCodeElimination(model_names)(mod)
121122
mod = relax.transform.LiftTransformParams()(mod)
122123
mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, model_names)

web_llm/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .dispatch_tir_operator import DispatchTIROperator
22
from .quantization import GroupQuantize
33
from .transpose_matmul import FuseTransposeMatmul
4+
from .decode_NT_matmul_ewise import FuseDecodeNTMatmulEwise
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import tvm
2+
from tvm import IRModule
3+
from tvm import relax, tir
4+
from tvm.relax.dpl.pattern import is_op, wildcard
5+
from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern
6+
7+
8+
def check_x_1dim(ctx: relax.transform.PatternCheckContext) -> bool:
    """Accept only GeMV-shaped inputs: x's second-to-last dim is a static 1.

    The fused decode + NT-matmul kernel targets the matrix-vector case, so
    the row dimension of ``x`` must be the compile-time constant 1.
    """
    rows = ctx.annotated_expr["x"].struct_info.shape[-2]
    if not isinstance(rows, tir.IntImm):
        # Dynamic (symbolic) dimension: cannot prove it is 1, reject.
        return False
    return rows.value == 1
12+
13+
14+
def check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:
    """Require the weight-producing call_tir to target a ``decode*`` TIR func.

    The matched ``w`` is a ``relax.call_tir`` whose first argument is the
    GlobalVar of the invoked PrimFunc; we gate fusion on its name prefix.
    """
    decode_gv = ctx.annotated_expr["w"].args[0]
    return decode_gv.name_hint.startswith("decode")
18+
19+
20+
def check_NT_matmul(ctx: relax.transform.PatternCheckContext) -> bool:
    """Check that the matched call_tir invokes an NT-matmul TIR function.

    Accepts both plain ``NT_matmul*`` kernels and ``fused_NT_matmul*``
    variants (presumably produced by an earlier fusion pass — confirm
    against the pipeline in build.py).
    """
    call = ctx.annotated_expr["NT_matmul"]
    gv = call.args[0]
    # str.startswith accepts a tuple of prefixes: one call instead of an
    # `or` chain of two startswith calls.
    return gv.name_hint.startswith(("NT_matmul", "fused_NT_matmul"))
24+
25+
26+
def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool:
    """Combined fusion predicate: every sub-check must pass.

    ``all`` over a generator short-circuits exactly like the original
    ``and`` chain, so evaluation order and early exit are preserved.
    """
    checks = (check_x_1dim, check_decoding, check_NT_matmul)
    return all(check(ctx) for check in checks)
28+
29+
30+
def decode_NT_matmul_pattern():
    """Dataflow pattern: decode call_tir feeding an NT-matmul call_tir.

    Returns the ``(root_pattern, annotations, check_fn)`` triple expected
    by ``relax.transform.FuseOpsByPattern``.
    """
    quantized_weight = wildcard()
    scale_and_min = wildcard()
    activations = wildcard()

    # decode: call_tir(gv, (w_scaled, scale_min))
    decoded_weight = is_op("relax.call_tir")(
        GlobalVarPattern(),
        TuplePattern([quantized_weight, scale_and_min]),
        add_constraint=False,
    )
    # NT-matmul: call_tir(gv, (x, w)) consuming the decoded weight
    matmul_call = is_op("relax.call_tir")(
        GlobalVarPattern(),
        TuplePattern([activations, decoded_weight]),
        add_constraint=False,
    )

    annotations = {
        "NT_matmul": matmul_call,
        "w": decoded_weight,
        "x": activations,
        "w_scaled": quantized_weight,
        "scale_min": scale_and_min,
    }
    return matmul_call, annotations, pattern_check
50+
51+
52+
def decode_NT_matmul_ewise_pattern():
    """Dataflow pattern: decode + NT-matmul fused with a trailing ewise op.

    Like ``decode_NT_matmul_pattern`` but the matmul call_tir takes a third
    operand ``y`` (the element-wise input). Returns the
    ``(root_pattern, annotations, check_fn)`` triple for
    ``relax.transform.FuseOpsByPattern``.
    """
    quantized_weight = wildcard()
    scale_and_min = wildcard()
    activations = wildcard()
    ewise_input = wildcard()

    # decode: call_tir(gv, (w_scaled, scale_min))
    decoded_weight = is_op("relax.call_tir")(
        GlobalVarPattern(),
        TuplePattern([quantized_weight, scale_and_min]),
        add_constraint=False,
    )
    # NT-matmul + ewise: call_tir(gv, (x, w, y))
    matmul_ewise_call = is_op("relax.call_tir")(
        GlobalVarPattern(),
        TuplePattern([activations, decoded_weight, ewise_input]),
        add_constraint=False,
    )

    annotations = {
        "NT_matmul": matmul_ewise_call,
        "w": decoded_weight,
        "x": activations,
        "w_scaled": quantized_weight,
        "scale_min": scale_and_min,
    }
    return matmul_ewise_call, annotations, pattern_check
73+
74+
75+
@tvm.transform.module_pass(opt_level=0, name="FuseDecodeNTMatmulEwise")
class FuseDecodeNTMatmulEwise:
    """Module pass fusing decode + NT-matmul (optionally + trailing ewise).

    Runs two FuseOpsByPattern passes (plain matmul first, then the ewise
    variant) and finally FuseTIR to inline each matched group into a single
    fused TIR PrimFunc.
    """

    def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
        fusion_specs = (
            ("decode_NT_matmul", decode_NT_matmul_pattern),
            ("decode_NT_matmul_ewise", decode_NT_matmul_ewise_pattern),
        )
        # One FuseOpsByPattern pass per pattern, in the same order as the
        # original two-step sequence.
        for pattern_name, make_pattern in fusion_specs:
            mod = relax.transform.FuseOpsByPattern(
                [(pattern_name, *make_pattern())]
            )(mod)
        return relax.transform.FuseTIR()(mod)

0 commit comments

Comments
 (0)