Description
Hi,
I've rewritten the matmul algorithm to peel off the first iteration of the summation, as shown below. When I run the following code,
import tvm
M = 600
N = 650
P = 700
# Algorithm
k = tvm.reduce_axis((1, N), 'k') # note the starting value of 1 (instead of 0)
A = tvm.placeholder((M, N), name = 'A')
B = tvm.placeholder((N, P), name = 'B')
C = tvm.compute(
(M, P),
lambda x, y: A[x,0]*B[0,y] + tvm.sum(A[x, k] * B[k, y], axis = k),
name = 'C')
# Default schedule
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
I get the following error message. Am I doing anything wrong? Is there a way to specify the matmul algorithm while explicitly peeling the first iteration of the sum reduction?
The reason I want to do this is that I don't want TVM to generate zero-initialization code for C. Instead, I want C[x,y] to be initialized to A[x,0]*B[0,y], followed by the summation code over the remaining values of k (1 to N-1).
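For reference, the intended semantics (initialize C[x,y] with the k = 0 term, then accumulate k = 1..N-1 with no zero fill) can be written as a plain-Python loop nest and checked against the ordinary zero-initialized matmul. This is a minimal sketch, not TVM code and not what `tvm.lower` actually emits; the function names `matmul_zero_init` and `matmul_peeled` are hypothetical, and it assumes N >= 1.

```python
# Hypothetical pure-Python reference (not TVM): the loop nest the peeled
# formulation describes, compared with the usual zero-initialized matmul.


def matmul_zero_init(A, B):
    """Standard matmul: C[x][y] starts at 0 and accumulates over k = 0..N-1."""
    M, N, P = len(A), len(B), len(B[0])
    C = [[0] * P for _ in range(M)]
    for x in range(M):
        for y in range(P):
            for k in range(N):
                C[x][y] += A[x][k] * B[k][y]
    return C


def matmul_peeled(A, B):
    """Peeled matmul: C[x][y] starts at A[x][0] * B[0][y] (no zero fill),
    then accumulates over k = 1..N-1, mirroring reduce_axis((1, N)).
    Assumes N >= 1."""
    M, N, P = len(A), len(B), len(B[0])
    # Initialization with the peeled first iteration (k = 0).
    C = [[A[x][0] * B[0][y] for y in range(P)] for x in range(M)]
    # Remaining iterations of the reduction.
    for x in range(M):
        for y in range(P):
            for k in range(1, N):
                C[x][y] += A[x][k] * B[k][y]
    return C


if __name__ == "__main__":
    A = [[1, 2, 3], [4, 5, 6]]        # M = 2, N = 3
    B = [[7, 8], [9, 10], [11, 12]]   # N = 3, P = 2
    print(matmul_peeled(A, B))        # [[58, 64], [139, 154]]
    assert matmul_peeled(A, B) == matmul_zero_init(A, B)
```

Both functions compute the same product; the only difference is whether the accumulator starts at zero or at the k = 0 term.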
Thanks!
[17:01:51] /hwapwork/s00327669/software/tvm/dmlc-core/include/dmlc/logging.h:308: [17:01:51] src/lang/ir.cc:23: Reduce do not work with old Visitor, use IRFunctor style visitor
Stack trace returned 10 entries:
[bt] (0) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x39) [0x7faae5e37bb9]
[bt] (1) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::ExprNode<tvm::ir::Reduce>::accept(Halide::Internal::IRVisitor*, Halide::Expr const&) const+0x4f) [0x7faae5ed7c9f]
[bt] (2) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::IRMutator::mutate(Halide::Expr)+0x1a) [0x7faae6193cba]
[bt] (3) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::Simplify::visit(Halide::Internal::Add const*, Halide::Expr const&)+0xdf) [0x7faae61319af]
[bt] (4) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::IRMutator::mutate(Halide::Expr)+0x1a) [0x7faae6193cba]
[bt] (5) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::Simplify::visit(Halide::Internal::Store const*, Halide::Internal::Stmt const&)+0xca) [0x7faae6114f1a]
[bt] (6) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::IRMutator::mutate(Halide::Internal::Stmt)+0x1a) [0x7faae6193d2a]
[bt] (7) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::Simplify::visit(Halide::Internal::For const*, Halide::Internal::Stmt const&)+0x15a) [0x7faae6163aca]
[bt] (8) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::IRMutator::mutate(Halide::Internal::Stmt)+0x1a) [0x7faae6193d2a]
[bt] (9) /home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(Halide::Internal::Simplify::visit(Halide::Internal::For const*, Halide::Internal::Stmt const&)+0x15a) [0x7faae6163aca]
Traceback (most recent call last):
File "test_sum.py", line 61, in <module>
print(tvm.lower(s, [A, B, C], simple_mode=True))
File "/home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/build_module.py", line 228, in lower
stmt = ir_pass.Simplify(stmt)
File "/home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/_ffi/function.py", line 255, in my_api_func
return flocal(*args)
File "/home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 183, in __call__
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
File "/home/s00327669/.local/lib/python2.7/site-packages/tvm-0.1.0-py2.7-linux-x86_64.egg/tvm/_ffi/base.py", line 62, in check_call
raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [17:01:51] src/lang/ir.cc:23: Reduce do not work with old Visitor, use IRFunctor style visitor