[Bug][Relax] Shape Mismatch for function argument #17310

Open
Cookiee235 opened this issue Aug 28, 2024 · 2 comments
Labels: needs-triage, type: bug

Cookiee235 (Contributor) commented Aug 28, 2024
Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/res_ut/res_executions/30_test.py", line 50, in <module>
    ex = relax.build(mod, target='llvm')
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  37: tvm::transform::Pass::operator()(tvm::IRModule) const
  36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  35: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  33: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  32: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  30: tvm::relax::CallTIRMutator::Run()
  29: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  28: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  27: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  26: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  25: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  22: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  21: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  20: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  19: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  18: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  17: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  16: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
  15: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  11: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
  10: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
  9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
  8: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
  7: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  5: non-virtual thunk to tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
  1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "/software/tvm/src/relax/ir/block_builder.cc", line 159
TVMError: Argument 0 type mismatch: expected R.Tensor((64, 64, 56, 56), dtype="float32"), given R.Tensor((1, 64, 56, 56), dtype="float32")

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(data: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), weight1: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(58), T.int64(58)))
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(58), T.int64(58)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(57) and T.int64(1) <= v_i3 and v_i3 < T.int64(57), data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0))
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(64), T.int64(3), T.int64(3)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], weight1[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * weight1[v_ff, v_rc, v_ry, v_rx]

    @T.prim_func
    def relu(data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32")):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
            with T.block("root"):
                i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(data[i, j, k, l])
                T.writes(out[i, j, k, l])
                out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
        cls = Module
        with R.dataflow():
            conv1 = R.call_tir(cls.conv2d, (data, weight1), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
            relu1 = R.call_tir(cls.relu, (conv1,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"))
            R.output(relu1)
        return relu1

mod = Module
mod.show()
ex = relax.build(mod, target='llvm')

The given Relax IR passes the IR validity check but crashes when built. Could you help me review it? Thanks a lot!

CC @Lunderberg @junrushao

Cookiee235 added the needs-triage and type: bug labels on Aug 28, 2024
xhmelon (Contributor) commented Sep 18, 2024

Hi @Cookiee235,
The error is caused by a mismatch between the output shape of conv2d and the input shape of relu, which are (1, 64, 56, 56) and (64, 64, 56, 56), respectively. Changing the buffer shapes of relu from (64, 64, 56, 56) to (1, 64, 56, 56) makes the module build successfully.
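
For illustration, a minimal sketch of that corrected relu; note the out_sinfo of the second call_tir and the return annotation of main would need to be updated to (1, 64, 56, 56) as well:

@T.prim_func
def relu(data: T.Buffer((1, 64, 56, 56), "float32"), out: T.Buffer((1, 64, 56, 56), "float32")):
    # Buffer shapes now match conv2d's output shape (1, 64, 56, 56).
    for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
        with T.block("root"):
            i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(data[i, j, k, l])
            T.writes(out[i, j, k, l])
            out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))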

Cookiee235 (Contributor, Author) commented

@xhmelon Thanks for your investigation. Indeed, the Relax IR is invalid, and the crash message gives the correct diagnosis. However, this Relax IR passes the verify_well_formed validation, which misleads us into considering it valid. It would be better to catch the error earlier (i.e., fail at the mod = Module statement)!
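
For what it's worth, a minimal sketch of how the mismatch could be surfaced before relax.build, assuming (per the traceback above) that the struct-info derivation fails inside the CallTIRRewrite pass:

import tvm
from tvm import relax

# verify_well_formed accepts the module even though the relu call_tir's
# out_sinfo disagrees with the prim_func's declared buffer shapes.
assert relax.analysis.well_formed(mod)

# Running the pass identified in the traceback directly should raise the
# same TVMError without a full build:
relax.transform.CallTIRRewrite()(mod)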
