
Commit 3fc8133

YuchenJin authored and yongwww committed
Update Shape lowering pass (apache#38)
* Update shape lowering pass.
* Rebase.
1 parent 72d43d5 commit 3fc8133

File tree: 3 files changed (+176 −36 lines)


src/relax/backend/vm/vm_shape_lower.cc

Lines changed: 59 additions & 18 deletions
@@ -32,6 +32,38 @@
 namespace tvm {
 namespace relax {
 
+/*!
+ * \brief Visitor to apply a function to every Expr it visits. Also applies the function
+ * to the shape field of the var definition site if the var's shape is a ShapeExpr.
+ */
+class ExprApplyVisitWithShape : public ExprVisitor {
+ public:
+  explicit ExprApplyVisitWithShape(std::function<void(const Expr&)> f) : f_(f) {}
+
+  void VisitVarDef(const Var& var) {
+    if (var.as<DataflowVarNode>()) {
+      this->VisitExpr(Downcast<DataflowVar>(var));
+    } else {
+      this->VisitExpr(var);
+    }
+    if (var->shape_.operator bool() && var->shape_.value().as<ShapeExprNode>()) {
+      f_(Downcast<ShapeExpr>(var->shape_.value()));
+    }
+  }
+
+  void VisitExpr(const Expr& e) final {
+    ExprVisitor::VisitExpr(e);
+    f_(e);
+  }
+
+ private:
+  std::function<void(const Expr&)> f_;
+};
+
+void PostOrderVisitWithShape(const Expr& e, std::function<void(const Expr&)> fvisit) {
+  ExprApplyVisitWithShape(fvisit).VisitExpr(e);
+}
+
 class VMShapeLowerMutator : public ExprMutator {
  public:
   static DataType ShapeDType() { return DataType::Int(64); };
@@ -58,18 +90,11 @@ class VMShapeLowerMutator : public ExprMutator {
   }
 
   void VisitBinding_(const MatchShapeNode* binding) override {
-    Expr shape = ExprMutator::VisitExpr(binding->value);
-    static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape");
-    auto store_shape_attr = make_object<ShapeHeapAttrs>();
-
-    Array<PrimExpr> pattern = binding->pattern;
-    Array<Integer> indices;
-    for (size_t i = 0; i < pattern.size(); ++i) {
-      int idx = expr2slot_.at(pattern[i]);
-      indices.push_back(idx);
-    }
-    store_shape_attr->indices = indices;
-    builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv");
+    Expr value = ExprMutator::VisitExpr(binding->value);
+
+    // TODO(@yuchen): match_shape overloaded semantic: value is ShapeType
+    Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {value}), "sh");
+    StoreShape(shape, binding->pattern);
   }
 
   Expr VisitExpr_(const ShapeExprNode* node) override {
@@ -97,16 +122,18 @@
   }
 
   Expr VisitExpr_(const FunctionNode* node) override {
+    builder_->BeginBindingBlock();
+    builder_->Emit(VarBinding(
+        shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
     Array<Var> params;
     for (Var param : node->params) {
       params.push_back(this->VisitVarDef(param));
+      if (param->shape_.operator bool() && param->shape_.value().as<ShapeExprNode>()) {
+        Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh");
+        StoreShape(shape, Downcast<ShapeExpr>(param->shape_.value())->values);
+      }
     }
     Type ret_type = this->VisitType(node->ret_type);
-
-    builder_->BeginBindingBlock();
-    builder_->Emit(VarBinding(
-        shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
-
     Expr new_body = this->VisitExpr(node->body);
 
     Array<BindingBlock> blocks;
@@ -174,10 +201,24 @@
         }
       }
     };
-    PostOrderVisit(expr, func);
+    PostOrderVisitWithShape(expr, func);
    return ret;
  }
 
+  /*! \brief Store symbolic shape into indices of the VM shape heap. */
+  void StoreShape(Expr shape, Array<PrimExpr> pattern) {
+    static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape");
+    auto store_shape_attr = make_object<ShapeHeapAttrs>();
+
+    Array<Integer> indices;
+    for (size_t i = 0; i < pattern.size(); ++i) {
+      int idx = expr2slot_.at(pattern[i]);
+      indices.push_back(idx);
+    }
+    store_shape_attr->indices = indices;
+    builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv");
+  }
+
   bool IsConstantShape(ShapeExpr shape) const {
     for (PrimExpr e : shape->values) {
       if (!e->IsInstance<IntImmNode>()) {
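With this change, VMShapeLowerMutator allocates the shape heap at the start of the function and, for every parameter whose annotated shape is a ShapeExpr, emits vm.builtin.shape_of followed by relax.vm.builtin.store_shape, so symbolic parameter shapes reach the heap without an explicit shape_of call in the source program. A minimal usage sketch, mirroring test_vm_shape_lowering in the test file below (the module name ShapeExample is illustrative; the TVMScript syntax, pass invocation, and assertions are taken from that test):

import tvm
import tvm.script
from tvm import relax
from tvm.script import relax as R


# Illustrative module name; test_vm_shape_lowering uses TestVMShapeLower.
# The result shape depends on the symbolic dimensions n and m.
@tvm.script.ir_module
class ShapeExample:
    @R.function
    def foo(x: Tensor[_, "float32"]) -> Shape:
        relax.match_shape(x, (n, m))
        return (n * 2, m * 3)


# Run the updated shape lowering pass.
new_mod = relax.transform.VMShapeLower()(ShapeExample)

# The first binding of the lowered function allocates the shape heap;
# match_shape is lowered into vm.builtin.shape_of + relax.vm.builtin.store_shape.
func = new_mod["foo"]
s1 = func.body.blocks[0].bindings[0].value
assert isinstance(s1.op, relax.ExternFunc)
assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap"
assert s1.args[0].values[0] == 4  # slots for n, m, n * 2, m * 3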

tests/python/relax/test_transform.py

Lines changed: 63 additions & 3 deletions
@@ -23,7 +23,7 @@
 from tvm.ir.module import IRModule
 
 import tvm.script
-from tvm.script import relax as R
+from tvm.script import tir as T, relax as R
 
 
 def test_fma_rewrite():
@@ -179,8 +179,7 @@ def test_vm_shape_lowering():
     class TestVMShapeLower:
         @R.function
         def foo(x: Tensor[_, "float32"]) -> Shape:
-            sh = relax.call_packed("vm.builtin.shape_of", x)
-            relax.match_shape(sh, (n, m))
+            relax.match_shape(x, (n, m))
             return (n * 2, m * 3)
 
     mod = TestVMShapeLower
@@ -196,6 +195,7 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
     s1 = func.body.blocks[0].bindings[0].value
     assert isinstance(s1.op, relax.ExternFunc)
     assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap"
+    assert s1.args[0].values[0] == 4
     s2 = func.body.blocks[1].bindings[0].value
     assert isinstance(s2.op, relax.ExternFunc)
     assert s2.op.global_symbol == "vm.builtin.shape_of"
@@ -209,6 +209,65 @@ def foo(x: Tensor[_, "float32"]) -> Shape:
     assert isinstance(s5, tvm.relay.Call)
     assert s5.op.name == "relax.vm.builtin.load_shape"
 
+
+def test_vm_shape_lowering_func_param_with_shape():
+    src = """@tvm.script.ir_module
+class InputModule:
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        T.func_attr({"global_symbol": "tir_matmul"})
+        m = T.var("int32")
+        n = T.var("int32")
+        k = T.var("int32")
+        A = T.match_buffer(x, (m,n))
+        B = T.match_buffer(y, (n,k))
+        C = T.match_buffer(z, (m,k))
+
+        for i, j, k in T.grid(m, k, n):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+    @R.function
+    def foo(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor:
+        gv0 = R.call_dps((m, k), tir_matmul, (x, w))
+        return gv0
+"""
+    mod = tvm.script.relax.parser.from_source(src)
+
+    # after vm shape lowering
+    new_mod = relax.transform.VMShapeLower()(mod)
+
+    assert isinstance(new_mod, tvm.IRModule)
+    assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc)
+    assert isinstance(new_mod["tir_matmul"], tvm.tir.function.PrimFunc)
+    func = new_mod["foo"]
+    assert isinstance(func, tvm.relax.expr.Function)
+
+    x, w = func.params
+    s1 = func.body.blocks[0].bindings[0].value
+    assert isinstance(s1.op, relax.ExternFunc)
+    assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap"
+    assert s1.args[0].values[0] == 3
+
+    s2 = func.body.blocks[0].bindings[1].value
+    assert isinstance(s2.op, relax.ExternFunc)
+    assert s2.op.global_symbol == "vm.builtin.shape_of"
+    assert s2.args[0] == x
+    s3 = func.body.blocks[0].bindings[2].value
+    assert isinstance(s3, tvm.relay.Call)
+    assert s3.op.name == "relax.vm.builtin.store_shape"
+
+    s4 = func.body.blocks[0].bindings[3].value
+    assert isinstance(s4.op, relax.ExternFunc)
+    assert s4.op.global_symbol == "vm.builtin.shape_of"
+    assert s4.args[0] == w
+    s5 = func.body.blocks[0].bindings[4].value
+    assert isinstance(s5, tvm.relay.Call)
+    assert s5.op.name == "relax.vm.builtin.store_shape"
+
+
 def test_to_anf():
     x = relax.Var("x", type_annotation=relax.DynTensorType())
     gv = relax.op.add(x, x)
@@ -241,4 +300,5 @@ def f(x: Tensor[_, "float32"]):
     test_call_dps_rewrite()
     test_vm_memory_lower()
     test_vm_shape_lowering()
+    test_vm_shape_lowering_func_param_with_shape()
     test_to_anf()
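The shape-heap sizes asserted above (4 and 3) follow from how many distinct symbolic shape expressions the pass assigns slots to: test_vm_shape_lowering tracks n, m, n * 2, and m * 3, while test_vm_shape_lowering_func_param_with_shape tracks only m, n, and k across the parameter and call_dps shapes. A back-of-the-envelope check; the helper below is purely illustrative and not part of the pass:

# One shape-heap slot per distinct symbolic shape expression
# (illustrative stand-in for the expr2slot_ map in vm_shape_lower.cc).
def expected_heap_size(exprs):
    return len(set(exprs))


# test_vm_shape_lowering: n, m from match_shape plus n * 2, m * 3 from the return shape.
assert expected_heap_size(["n", "m", "n * 2", "m * 3"]) == 4
# test_vm_shape_lowering_func_param_with_shape: m, n, k from (m, n), (n, k), and (m, k).
assert expected_heap_size(["m", "n", "k"]) == 3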

tests/python/relax/test_vm.py

Lines changed: 54 additions & 15 deletions
@@ -232,7 +232,7 @@ def test_vm_compile_stage0():
     class TestVMCompileStage0:
         @R.function
         def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]):
-            z = relax.call_packed("test.vm.identity", x, y)
+            z = R.call_packed("test.vm.identity", x, y)
             return y
 
     mod = TestVMCompileStage0
@@ -272,13 +272,13 @@ def shape_func0(heap: T.handle) -> None:
 
     @R.function
     def foo(x: Tensor[_, "float32"]) -> Shape:
-        shape_heap: Tensor[(4,), "int64"] = relax.call_packed(
+        shape_heap: Tensor[(4,), "int64"] = R.call_packed(
            "vm.builtin.alloc_shape_heap", (4,)
        )
-        gv0 = relax.call_packed("vm.builtin.shape_of", x)
-        gv1 = relax.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1))
+        gv0 = R.call_packed("vm.builtin.shape_of", x)
+        gv1 = R.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1))
         gv2 = shape_func0(shape_heap)
-        gv3 = relax.call_packed("vm.builtin.load_shape", shape_heap, (2, 3))
+        gv3 = R.call_packed("vm.builtin.load_shape", shape_heap, (2, 3))
         return gv3
 """
 
@@ -301,8 +301,7 @@ def test_vm_compile_stage2():
     class TestVMCompileStage2:
         @R.function
         def foo(x: Tensor[_, "float32"]) -> Shape:
-            sh = relax.call_packed("vm.builtin.shape_of", x)
-            relax.match_shape(sh, (n, m))
+            R.match_shape(x, (n, m))
             return (n * 2, m * 3)
 
     mod = TestVMCompileStage2
@@ -323,9 +322,9 @@ def test_vm_compile_stage3():
     class TestVMCompileStage3:
         @R.function
         def foo(x: Tensor[(32, 16), "float32"]) -> Tensor:
-            with relax.dataflow():
-                y = relax.call_dps((32, 16), "test.vm.identity", (x))
-                relax.output(y)
+            with R.dataflow():
+                y = R.call_dps((32, 16), "test.vm.identity", (x))
+                R.output(y)
             return y
 
     mod = TestVMCompileStage3
@@ -345,11 +344,10 @@ def test_vm_compile_e2e():
     class TestVMCompileE2E:
         @R.function
         def foo(x: Tensor[_, "float32"]) -> Tensor:
-            with relax.dataflow():
-                sh = relax.call_packed("vm.builtin.shape_of", x)
-                x0 = relax.match_shape(sh, (n, m))
-                y = relax.call_dps((n, m * 2), "test.vm.tile", (x))
-                relax.output(y)
+            with R.dataflow():
+                R.match_shape(x, (n, m))
+                y = R.call_dps((n, m * 2), "test.vm.tile", (x))
+                R.output(y)
             return y
 
     mod = TestVMCompileE2E
@@ -364,6 +362,46 @@ def foo(x: Tensor[_, "float32"]) -> Tensor:
     res = vm["foo"](inp)
     np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), res.asnumpy())
 
+def test_vm_compile_e2e_func_param_with_shape():
+    src = """@tvm.script.ir_module
+class TestVMCompileE2E2:
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        T.func_attr({"global_symbol": "tir_matmul"})
+        m = T.var("int32")
+        n = T.var("int32")
+        k = T.var("int32")
+        A = T.match_buffer(x, (m,n))
+        B = T.match_buffer(y, (n,k))
+        C = T.match_buffer(z, (m,k))
+
+        for i, j, k in T.grid(m, k, n):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+    @R.function
+    def func(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor:
+        gv0 = R.call_dps((m, k), tir_matmul, (x, w))
+        return gv0
+"""
+
+    mod = tvm.script.relax.parser.from_source(src)
+
+    target = tvm.target.Target("llvm")
+    target_host = tvm.target.Target("llvm")
+    ex, lib = relax.vm.build(mod, target, target_host)
+    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
+
+    import numpy as np
+    data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
+    weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
+    res = vm["func"](data, weight)
+    expected = np.dot(data.asnumpy(), weight.asnumpy())
+    np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4)
+
 
 if __name__ == "__main__":
     test_vm_execute()
@@ -380,3 +418,4 @@ def foo(x: Tensor[_, "float32"]) -> Tensor:
     test_vm_compile_stage2()
     test_vm_compile_stage3()
     test_vm_compile_e2e()
+    test_vm_compile_e2e_func_param_with_shape()
