src/tir/schedule/primitive/compute_at.cc (15 additions, 0 deletions)
@@ -680,6 +680,20 @@ void CalculateProvidedRequiredRegions(

/******** Main Implementation ********/

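/*!
 * \brief Walk up to the root sref, then mark every shape expression in the
 * root PrimFunc's buffer_map as non-negative in the analyzer, so that bound
 * proofs involving symbolic shape variables (e.g. `n`) can succeed.
 */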
void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
                       arith::Analyzer* analyzer) {
  while (sref->parent != nullptr) {
    sref = sref->parent;
  }
  const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
  for (const auto& kv : f->buffer_map) {
    const Buffer& buffer = kv.second;
    for (const PrimExpr& e : buffer->shape) {
      analyzer->MarkGlobalNonNegValue(e);
    }
  }
}

template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, bool preserve_unit_loops,

@@ -692,6 +706,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
                                          /*require_stage_pipeline=*/true);
  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  AddShapeVarBounds(self, scope_root_sref.get(), analyzer);
  BlockScope scope = self->GetBlockScope(scope_root_sref);
  Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
  Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
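The new helper matters because, without it, the analyzer has no information about a symbolic extent such as `n` and must treat it as spanning the full int32 range, so bounds like `ax0_ax1_fused // n < 32` cannot be proven during the region-cover check. A minimal sketch of the same idea through the public Python `tvm.arith.Analyzer` API (`MarkGlobalNonNegValue` itself is a C++-side method, so the sketch emulates it with a `ConstIntBound`; the variable `n` below is illustrative):

```python
import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()
n = tir.Var("n", "int32")

# With nothing registered, the analyzer assumes the full int32 range for n.
print(analyzer.const_int_bound(n))  # [-2147483648, 2147483647]

# Emulate the effect of the C++-side MarkGlobalNonNegValue by registering
# a non-negative constant-integer bound for n.
analyzer.update(n, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF))
print(analyzer.const_int_bound(n))  # min_value is now 0
```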
tests/python/unittest/test_tir_schedule_compute_at.py (69 additions, 0 deletions)
@@ -1823,5 +1823,74 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32"))
    verify_trace_roundtrip(sch=sch, mod=before)


def test_shape_var_as_bound():
    # fmt: off
    @T.prim_func
    def before(a: T.handle, b: T.handle, c: T.handle):
        n = T.int32()
        A = T.match_buffer(a, (32, 1, 128))
        B = T.match_buffer(b, (32, n, 128))
        C = T.match_buffer(c, (32, 1, n))
        # with T.block("root"):
        C_rf = T.alloc_buffer((128, 32, 1, n))
        for ax0_ax1_fused, ax2_fused_1, ax2_fused_0 in T.grid(n * 32, 128, 1):
            with T.block("NT_matmul_rf"):
                vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
                v0 = T.axis.spatial(32, ax0_ax1_fused // n)
                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
                vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
                T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1])
                T.writes(C_rf[vax2_fused_1, v0, 0, v1])
                with T.init():
                    C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
                C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1]
        for ax0_ax1_fused, ax2_fused_1 in T.grid(n * 32, 128):
            with T.block("NT_matmul"):
                vax2_fused_1 = T.axis.reduce(128, ax2_fused_1)
                v0 = T.axis.spatial(32, ax0_ax1_fused // n)
                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
                T.reads(C_rf[vax2_fused_1, v0, 0, v1])
                T.writes(C[v0, 0, v1])
                with T.init():
                    C[v0, 0, v1] = T.float32(0)
                C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]

    @T.prim_func
    def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle):
        n = T.int32()
        B = T.match_buffer(b, (32, n, 128))
        C = T.match_buffer(c, (32, 1, n))
        # with T.block("root"):
        C_rf = T.alloc_buffer((128, 32, 1, n))
        for ax0_ax1_fused in range(n * 32):
            for ax2_fused_1, ax2_fused_0 in T.grid(128, 1):
                with T.block("NT_matmul_rf"):
                    vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
                    v0 = T.axis.spatial(32, ax0_ax1_fused // n)
                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
                    vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
                    T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1])
                    T.writes(C_rf[vax2_fused_1, v0, 0, v1])
                    with T.init():
                        C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
                    C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1]
            for ax0, ax1, ax2 in T.grid(128, 1, 1):
                with T.block("NT_matmul"):
                    vax2_fused_1 = T.axis.reduce(128, ax0)
                    v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax1)
                    v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax2)
                    T.reads(C_rf[vax2_fused_1, v0, 0, v1])
                    T.writes(C[v0, 0, v1])
                    with T.init():
                        C[v0, 0, v1] = T.float32(0)
                    C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
    # fmt: on
    sch = tir.Schedule(before, debug_mask="all")
    block = sch.get_block("NT_matmul")
    loop, _, _ = sch.get_loops(sch.get_block("NT_matmul_rf"))
    sch.reverse_compute_at(block, loop, preserve_unit_loops=True)
    tvm.ir.assert_structural_equal(sch.mod["main"], expected, True)


if __name__ == "__main__":
    tvm.testing.main()
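For reference, the new test can be run on its own from a TVM checkout with `pytest tests/python/unittest/test_tir_schedule_compute_at.py::test_shape_var_as_bound`.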