Skip to content

Commit 3e00253

Browse files
author
Siyuan Feng
authored
[TIR] Fix Primitive Rfactor DType (#15413)
The rfactor primitive will create/rewrite two blocks, together with the block read/write regions. However, the generated read/write region extents are not valid when it's a int64 index. This commit fixes the issue.
1 parent 22ec541 commit 3e00253

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

src/tir/schedule/primitive/reduction.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,9 @@ class RFactorBlockCreator : public BaseBlockCreator {
912912
write_regions_.reserve(old_block->writes.size());
913913
for (const BufferRegion& write_region : old_block->writes) {
914914
Array<Range> region = write_region->region;
915-
region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1));
915+
region.insert(region.begin() + factor_axis_,
916+
Range::FromMinExtent(additional_iter_->var,
917+
make_const(additional_iter_->var.dtype(), 1)));
916918
Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
917919
ICHECK(rf_buffer.defined());
918920
write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_)));
@@ -1005,7 +1007,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
10051007
Array<Range> region;
10061008
region.reserve(buf_load->indices.size());
10071009
for (const PrimExpr& index : buf_load->indices) {
1008-
region.push_back(Range::FromMinExtent(index, 1));
1010+
region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
10091011
}
10101012
buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region)));
10111013
}

tests/python/unittest/test_tir_schedule_rfactor.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=missing-function-docstring,missing-module-docstring
1818
import pytest
19+
1920
import tvm
2021
import tvm.testing
2122
from tvm import te, tir, topi
@@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin():
16431644
verify_trace_roundtrip(s, mod=argmin_topi)
16441645

16451646

1647+
def test_reduction_rfactor_int64():
1648+
# fmt: off
1649+
@T.prim_func
1650+
def before(
1651+
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
1652+
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
1653+
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
1654+
):
1655+
for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
1656+
T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
1657+
):
1658+
with T.block("update"):
1659+
vi, vj = T.axis.remap("SS", [i0, i1])
1660+
vk = T.axis.R(
1661+
T.int64(128),
1662+
i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) + i2_inner_inner,
1663+
)
1664+
with T.init():
1665+
C[vi, vj] = 0.0
1666+
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
1667+
1668+
@T.prim_func
1669+
def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
1670+
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
1671+
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
1672+
):
1673+
C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32")
1674+
1675+
for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)):
1676+
with T.block("update_rf"):
1677+
vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer= T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer])
1678+
with T.init():
1679+
C_rf[vi2_inner_inner, vi, vj] = 0.0
1680+
C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
1681+
A[vi, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
1682+
* B[vj, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
1683+
)
1684+
1685+
for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)):
1686+
with T.block("update"):
1687+
vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1])
1688+
with T.init():
1689+
C[vi_1, vj_1] = 0.0
1690+
C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
1691+
# fmt: on
1692+
1693+
s = tir.Schedule(before, debug_mask="all")
1694+
update = s.get_block("update")
1695+
_, _, _, _, kii = s.get_loops(update)
1696+
rf_block = s.rfactor(kii, 0)
1697+
assert_structural_equal_ignore_global_symbol(s.mod["main"], expected)
1698+
assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
1699+
assert s.get(update).same_as(s.get(s.get_block("update")))
1700+
verify_trace_roundtrip(s, mod=before)
1701+
1702+
16461703
if __name__ == "__main__":
16471704
tvm.testing.main()

0 commit comments

Comments
 (0)