|
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=missing-function-docstring,missing-module-docstring |
18 | 18 | import pytest |
| 19 | + |
19 | 20 | import tvm |
20 | 21 | import tvm.testing |
21 | 22 | from tvm import te, tir, topi |
@@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin(): |
1643 | 1644 | verify_trace_roundtrip(s, mod=argmin_topi) |
1644 | 1645 |
|
1645 | 1646 |
|
def test_reduction_rfactor_int64():
    """Check ``rfactor`` on a reduction whose shapes and loop extents are all int64.

    ``before`` accumulates ``A[vi, vk] * B[vj, vk]`` into ``C`` with the
    reduction index assembled from three int64 loops.  The innermost of those
    loops is factored out with ``factor_axis=0`` and the scheduled module is
    compared against the hand-written ``expected`` PrimFunc.
    """
    # fmt: off
    @T.prim_func
    def before(
        A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
    ):
        for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
            T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
        ):
            with T.block("update"):
                vi, vj = T.axis.remap("SS", [i0, i1])
                # The reduction index is rebuilt from the three loops: 4 * 8 * 4 == 128.
                vk = T.axis.R(
                    T.int64(128),
                    i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) + i2_inner_inner,
                )
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])

    @T.prim_func
    def expected(
        A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
    ):
        # Leading axis (extent 4) is indexed by the factored-out innermost loop.
        C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32")

        for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
            T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
        ):
            with T.block("update_rf"):
                vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer = T.axis.remap(
                    "SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer]
                )
                with T.init():
                    C_rf[vi2_inner_inner, vi, vj] = 0.0
                C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
                    A[vi, vi2_outer * T.int64(32) + vi2_inner_outer * T.int64(4) + vi2_inner_inner]
                    * B[vj, vi2_outer * T.int64(32) + vi2_inner_outer * T.int64(4) + vi2_inner_inner]
                )

        # Second stage: reduce the partial results over the factored axis into C.
        for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)):
            with T.block("update"):
                vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap(
                    "RSS", [i2_inner_inner_1, i0_1, i1_1]
                )
                with T.init():
                    C[vi_1, vj_1] = 0.0
                C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
    # fmt: on

    sch = tir.Schedule(before, debug_mask="all")
    update_block = sch.get_block("update")
    # Exactly five loops surround the block; only the innermost one is factored.
    _i0, _i1, _k_outer, _k_inner_outer, k_inner_inner = sch.get_loops(update_block)
    rf_block = sch.rfactor(k_inner_inner, 0)

    assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected)

    # The returned handle must resolve to the new "update_rf" block, and the
    # original handle must still resolve to the block named "update".
    rf_stmt = sch.get(rf_block)
    assert rf_stmt.same_as(sch.get(sch.get_block("update_rf")))
    update_stmt = sch.get(update_block)
    assert update_stmt.same_as(sch.get(sch.get_block("update")))

    verify_trace_roundtrip(sch, mod=before)
| 1701 | + |
| 1702 | + |
# When this file is executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments