Skip to content

Commit 7deebc6

Browse files
[BUG] DataType Bug In SplitRel (#8899)
* [BUG] DataType Bug In SplitRel * Add Test Case
1 parent 707c4e0 commit 7deebc6

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/relay/op/tensor/transform.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2829,15 +2829,18 @@ bool SplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
28292829
}
28302830
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
28312831
} else {
2832-
auto indices = Downcast<Array<ObjectRef>>(param->indices_or_sections);
2832+
Array<IndexExpr> indices;
2833+
for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) {
2834+
indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value));
2835+
}
28332836
auto begin = IndexExpr(tir::make_zero(DataType::Int(32)));
28342837
std::vector<Type> fields;
28352838
for (unsigned int i = 0; i < indices.size(); ++i) {
2836-
ICHECK(reporter->Assert(Downcast<IndexExpr>(indices[i]) > begin))
2839+
ICHECK(reporter->Assert(indices[i] > begin))
28372840
<< "indices_or_sections need to be a sorted ascending list";
28382841
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
2839-
oshape[axis] = Downcast<IndexExpr>(indices[i]) - begin;
2840-
begin = Downcast<IndexExpr>(indices[i]);
2842+
oshape[axis] = indices[i] - begin;
2843+
begin = indices[i];
28412844
auto vec_type = TensorType(oshape, data->dtype);
28422845
fields.push_back(vec_type);
28432846
}

tests/python/relay/test_op_level3.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,21 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
516516
),
517517
axis=1,
518518
)
519+
verify_split(
520+
(d1, d2, d3, d4),
521+
tuple(np.array([2, 4, 7]).astype(np.int64)),
522+
relay.ty.TupleType(
523+
tvm.runtime.convert(
524+
[
525+
relay.ty.TensorType((d1, 2, d3, d4), "float32"),
526+
relay.ty.TensorType((d1, 2, d3, d4), "float32"),
527+
relay.ty.TensorType((d1, 3, d3, d4), "float32"),
528+
relay.ty.TensorType((d1, (d2 - 7), d3, d4), "float32"),
529+
]
530+
)
531+
),
532+
axis=1,
533+
)
519534

520535

521536
def test_full_infer_type():

0 commit comments

Comments
 (0)