
Commit 93cb308

hypercubestart authored and yongwww committed

[EmitTE] EmitTE Symbolic Shape (apache#53)

1 parent 7f7eff8 commit 93cb308

5 files changed, +155 -11 lines changed

python/tvm/relax/block_builder.py
24 additions, 8 deletions

@@ -148,14 +148,29 @@ def _convert_te_arg_helper(arg):
         new_arg = _convert_te_arg_helper(te_args)
         return new_arg, te_args_list
 
-    def _check_te_args(self, args: List[tvm.te.Tensor]):
-        """check te arguments."""
-        # TODO(hypercubestart, ziheng) support full dynamic shape in the future
-        for x in args:
+    def _check_te_args(self, args: List[tvm.te.Tensor], te_out: tvm.te.Tensor):
+        """check te arguments"""
+        # TODO(hypercubestart, ziheng) support case where match_buffer doesn't bind to a variable
+        tensors = args + [te_out]
+        bound_vars = set()
+        used_vars = set()
+
+        def _populate_used_vars(expr):
+            if isinstance(expr, tvm.tir.Var):
+                used_vars.add(expr)
+
+        for x in tensors:
             for s in x.shape:
-                if not isinstance(s, (tir.Var, tir.IntImm)):
-                    raise ValueError("emit_te not support symbolic shape"
-                                     "contains expression now: {}".format(x.shape))
+                tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars)
+                if isinstance(s, tir.Var):
+                    bound_vars.add(s)
+
+        diff = used_vars - bound_vars
+
+        if len(diff) != 0:
+            # there are TIR variables in shape expressions that are not bound by match_buffer
+            raise ValueError("emit_te does not support TE functions with unbound tir.Vars: {}".format(diff))
 
     def function(self,
                  params: Optional[Union[Var, Tuple, List[Var]]] = None,

@@ -280,12 +295,13 @@ def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tenso
         new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)
 
         te_args = te_arg_list + te_kwarg_list
-        self._check_te_args(te_args)
 
         # TODO(hypercubestart, ziheng) handle multiple output case
         te_out = func(*new_args, **new_kwargs)
         assert isinstance(te_out, tvm.te.tensor.Tensor), "only support te tensor as function output"
 
+        self._check_te_args(te_args, te_out)
+
         inputs = [*te_args, te_out]
         tir_func = tvm.te.create_prim_func(inputs)
         func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
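
To see what the tightened check accepts and rejects: shape dimensions may now be arbitrary TIR expressions, as long as every tir.Var appearing in them is also bound as a bare dimension of some input or output tensor (and hence by a match_buffer). A minimal sketch of the failure case, using the same BlockBuilder pattern as the tests added in this commit; the variable k is hypothetical and never appears as a bare dimension:

import tvm
from tvm import relax, te, tir

bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
k = tir.Var("k", "int64")  # hypothetical: only ever used inside a shape expression
x = relax.Var("x", [n], relax.DynTensorType(1, "float32"))

def te_func(A):
    # the output extent n + k is an expression, so the output's match_buffer
    # binds neither n nor k by itself; n is still bound by the input x,
    # but k is not, and used_vars - bound_vars == {k}
    return te.compute((n + k,), lambda i: A[0])

try:
    with bb.function([x], "rx_func"):
        out = bb.emit_te(te_func, x)
        bb.emit_func_output(out)
except ValueError as err:
    print(err)  # emit_te does not support TE functions with unbound tir.Vars: {k}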

src/relax/transform/call_dps_rewrite.cc
1 addition, 1 deletion

@@ -52,7 +52,7 @@ class CallDPSMutator : public ExprMutator {
 
     if (call->op == call_dps_op) {
      ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
-      Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
+      Var tensor = builder_->Emit(Call(alloc_tensor_op, {output_shape}), "alloc");
       Array<Expr> args;
       if (call->args[2].as<TupleNode>()) {
         args = Downcast<Tuple>(call->args[2])->fields;

src/relax/vm/executable.cc
8 additions, 0 deletions

@@ -74,6 +74,14 @@ std::string ExecutableNode::Stats() const {
       }
       oss.seekp(-2, oss.cur);
       oss << "], ";
+    } else if (it.IsObjectRef<ShapeTuple>()) {
+      ShapeTuple shape = it.operator ShapeTuple();
+      oss << "shapetuple[";
+      for (size_t i = 0; i < shape.size(); ++i) {
+        oss << shape.at(i) << ", ";
+      }
+      oss.seekp(-2, oss.cur);
+      oss << "], ";
     } else {
       try {
         DLDataType dtype = it.operator DLDataType();
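
For context, ExecutableNode::Stats() renders the VM's constant pool; with symbolic shapes the pool can now contain ShapeTuples, which previously fell through toward the DLDataType branch. A hedged sketch of how this surfaces from Python, assuming the Executable wrapper exposes Stats() as ex.stats() (that binding name is an assumption here):

# sketch: build any module whose constant pool carries a ShapeTuple,
# e.g. one produced via bb.emit_te with symbolic shapes as in the tests below
ex, lib = relax.vm.build(mod, tvm.target.Target("llvm"), tvm.target.Target("llvm"))
print(ex.stats())  # a ShapeTuple constant now prints as e.g. "shapetuple[5, 3], "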

tests/python/relax/test_vm.py
97 additions, 1 deletion

@@ -19,10 +19,11 @@
 import numpy as np
 import tvm
 from tvm.relay import Call
-from tvm import relax, tir
+from tvm import relax, tir, te
 from tvm.runtime import container
 import numpy as np
 
+from tvm.ir.base import assert_structural_equal
 import tvm.script
 from tvm.script import tir as T, relax as R
 
@@ -425,6 +426,98 @@ def test_vm_emit_te_extern():
     expected = np.dot(data.asnumpy(), weight.asnumpy())
     np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4)
 
+def test_vm_emit_te_concat():
+    # concatenate two vectors of size (n,) and (m,)
+    bb = relax.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    type_anno = relax.DynTensorType(1, "float32")
+    x = relax.Var("x", [n], type_anno)
+    y = relax.Var("y", [m], type_anno)
+
+    def te_func(A, B):
+        C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n]))
+        return C
+
+    with bb.function([x, y], "rx_func"):
+        x1 = bb.emit_te(te_func, x, y)
+        bb.emit_func_output(x1)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm")
+    target_host = tvm.target.Target("llvm")
+    ex, lib = relax.vm.build(mod, target, target_host)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
+    inp = tvm.nd.array(np.random.rand(1,).astype(np.float32))
+    inp2 = tvm.nd.array(np.random.rand(2,).astype(np.float32))
+    res = vm["rx_func"](inp, inp2)
+
+    np.testing.assert_allclose(res.asnumpy(), np.append(inp.asnumpy(), inp2.asnumpy()))
+
+def test_vm_emit_te_floor_symbolic_shape():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    type_anno = relax.DynTensorType(1, "float32")
+    x = relax.Var("x", [n], type_anno)
+
+    def te_func(A):
+        C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1)
+        return C
+
+    with bb.function([x], "rx_func"):
+        x1 = bb.emit_te(te_func, x)
+        bb.emit_func_output(x1)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm")
+    target_host = tvm.target.Target("llvm")
+    ex, lib = relax.vm.build(mod, target, target_host)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
+    shape = (9,)
+    inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+    res = vm["rx_func"](inp)
+
+    def expected_output():
+        output_shape = (shape[0] // 2,)
+        return inp.asnumpy()[:output_shape[0]] + 1
+
+    np.testing.assert_allclose(res.asnumpy(), expected_output())
+
+def test_vm_relax_symbolic_shape():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    type_anno = relax.DynTensorType(1, "float32")
+    x = relax.Var("x", [n], type_anno)
+    y = relax.Var("y", [(n // 2) + 1], type_anno)
+
+    def te_func(A, B):
+        C = te.compute((n,), lambda i: A[i] + B[i // 2])
+        return C
+
+    with bb.function([x, y], "rx_func"):
+        x1 = bb.emit_te(te_func, x, y)
+        bb.emit_func_output(x1)
+
+    mod = bb.get()
+
+    target = tvm.target.Target("llvm")
+    target_host = tvm.target.Target("llvm")
+    ex, lib = relax.vm.build(mod, target, target_host)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
+    shape1 = (5,)
+    shape2 = (3,)
+    inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32))
+    inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32))
+    res = vm["rx_func"](inp, inp2)
+
+    def expected_output():
+        return inp.asnumpy() + np.repeat(inp2.asnumpy(), 2)[:5]
+
+    np.testing.assert_allclose(res.asnumpy(), expected_output())
 
 if __name__ == "__main__":
     test_vm_execute()

@@ -443,3 +536,6 @@ def test_vm_emit_te_extern():
     test_vm_compile_e2e()
     test_vm_compile_e2e_func_param_with_shape()
     test_vm_emit_te_extern()
+    test_vm_emit_te_concat()
+    test_vm_emit_te_floor_symbolic_shape()
+    test_vm_relax_symbolic_shape()
tests/python/unittest/test_te_create_primfunc.py
25 additions, 1 deletion

@@ -13,7 +13,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
-# under the License.
+# under the License
 # pylint: disable=missing-function-docstring,missing-module-docstring
 import numpy as np
 import tvm

@@ -345,6 +345,29 @@ def test_data_dependent_access():
     tvm.testing.assert_allclose(a_np[b_np], c.numpy())
 
 
+def test_loop_var_datatype():
+    def test_helper(dtype):
+        n = te.var("n", dtype)
+        A = te.placeholder((n,), name="A")
+        B = te.placeholder((n,), name="B", dtype="int32")
+        C = te.compute((n,), lambda i: A[i] + B[i])
+
+        func = te.create_prim_func([C, A, B])
+
+        assert func.body.block.body.loop_var.dtype == dtype
+
+        func = tvm.build(func)
+
+        a_np = np.random.uniform(size=(10,)).astype(A.dtype)
+        b_np = np.random.uniform(size=(10,)).astype(B.dtype)
+        c = tvm.nd.array(np.zeros(10, dtype=C.dtype))
+        func(c, tvm.nd.array(a_np), tvm.nd.array(b_np))
+        tvm.testing.assert_allclose(a_np + b_np, c.numpy())
+
+    test_helper("int32")
+    test_helper("int64")
+
+
 def test_select_simplify():
     placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
     tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")

@@ -568,3 +591,4 @@ def expected(
     test_argmax_val_idx()
     test_int64_indices()
     test_zero_dim_add()
+    test_loop_var_datatype()
