diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9f9ed1c075cd..3781107eeee1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2829,15 +2829,18 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, } reporter->Assign(types[1], TupleType(Array(fields))); } else { - auto indices = Downcast>(param->indices_or_sections); + Array indices; + for (auto i : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + } auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; for (unsigned int i = 0; i < indices.size(); ++i) { - ICHECK(reporter->Assert(Downcast(indices[i]) > begin)) + ICHECK(reporter->Assert(indices[i] > begin)) << "indices_or_sections need to be a sorted ascending list"; std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = Downcast(indices[i]) - begin; - begin = Downcast(indices[i]); + oshape[axis] = indices[i] - begin; + begin = indices[i]; auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 41a866a0a034..e0b95fe7fbf7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -516,6 +516,21 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): ), axis=1, ) + verify_split( + (d1, d2, d3, d4), + tuple(np.array([2, 4, 7]).astype(np.int64)), + relay.ty.TupleType( + tvm.runtime.convert( + [ + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 3, d3, d4), "float32"), + relay.ty.TensorType((d1, (d2 - 7), d3, d4), "float32"), + ] + ) + ), + axis=1, + ) def test_full_infer_type():