Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1581,14 +1581,14 @@ std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::Sche
tir::IterVarType type = GetLoopIterType(loop_sref);
if (type == tir::kDataPar) {
const int64_t* extent = GetLoopIntExtent(loop_sref);
if (*extent != -1) {
if (extent && *extent != -1) {
cum_space_len *= *extent;
} else {
return std::make_pair(-1, -1);
}
} else if (type == tir::kCommReduce) {
const int64_t* extent = GetLoopIntExtent(loop_sref);
if (*extent != -1) {
if (extent && *extent != -1) {
cum_reduce_len *= *extent;
} else {
return std::make_pair(-1, -1);
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ Array<ExprRV> ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int
int max_innermost_factor,
Optional<Array<Integer>> decision) {
TVM_TIR_SCHEDULE_BEGIN();
// Use a None RV object to denote an auto-inferred tile factor.
return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
max_innermost_factor, &decision));
max_innermost_factor, &decision),
/*convert_negone_to_none=*/true);
TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_);
throw;
}
Expand Down
12 changes: 10 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,12 @@ class ConcreteScheduleNode : public ScheduleNode {
/*!
* \brief Add a list of integers as random variables into the symbol table
* \param value The list of integers to be added to the symbol table
* \param convert_negone_to_none Whether to convert each -1 entry into a None RV,
* which is the convention used by certain primitives.
* \return The new random variables created
*/
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value,
bool convert_negone_to_none = false);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
/*!
Expand Down Expand Up @@ -360,10 +363,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) {
return std::move(rv);
}

inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value) {
inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value,
bool convert_negone_to_none) {
Array<ExprRV> results;
results.reserve(value.size());
for (int64_t v : value) {
if (convert_negone_to_none && v == -1) {
results.push_back(ExprRV(nullptr));
continue;
}
results.push_back(CreateRV(v));
}
return results;
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ Array<String> TranslateAddOutputRVs(
ICHECK(!rv_names->count(output))
<< "ValueError: The random variable has been produced once: " << rv_names->at(output);
String result{ObjectPtr<StringObj>{nullptr}};
if (output->IsInstance<BlockRVNode>()) {
if (!output.defined()) {
result = "_";
} else if (output->IsInstance<BlockRVNode>()) {
result = "b" + std::to_string(i);
} else if (output->IsInstance<LoopRVNode>()) {
result = "l" + std::to_string(i);
Expand Down
8 changes: 5 additions & 3 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array<runtime::Int>& candidat
Array<ExprRV> TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
int max_innermost_factor,
Optional<Array<Integer>> decision) {
Array<ExprRV> results = CreateRV(tir::SamplePerfectTile(
&this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision));

// Use a None RV object to denote an auto-inferred tile factor.
Array<ExprRV> results =
CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
max_innermost_factor, &decision),
/*convert_negone_to_none=*/true);
static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{loop_rv},
Expand Down
28 changes: 28 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,33 @@ def test_sample_perfect_tile_after_copy():
sch_copy.sample_perfect_tile(i, n=4)


def test_sample_perfect_tile_on_dynamic_loops():
    """Check perfect-tile sampling on a static loop and on a dynamic-extent loop.

    The static loop (extent 1024) must yield four factors whose product is 1024.
    For the dynamic loop (extent ``n``), the first returned factor must be ``None``
    and the remaining three must multiply to 1.
    """

    @T.prim_func
    def workload(a: T.handle) -> None:
        n = T.int32()
        A = T.match_buffer(a, (n, 1024))
        for i, j in T.grid(n, 1024):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                A[vi, vj] = 1.0

    sch = tir.Schedule(workload, debug_mask="all")
    block_rv = sch.get_block("B")
    dynamic_loop, static_loop = sch.get_loops(block_rv)

    # Static loop: all four sampled factors are concrete values.
    static_rvs = sch.sample_perfect_tile(static_loop, n=4)
    static_vals = []
    for rv in static_rvs:
        static_vals.append(sch.get(rv))
    static_prod = static_vals[0] * static_vals[1] * static_vals[2] * static_vals[3]
    assert static_prod == 1024

    # Dynamic loop: the leading factor is None; only the rest are looked up.
    dynamic_rvs = sch.sample_perfect_tile(dynamic_loop, n=4)
    assert dynamic_rvs[0] is None
    dynamic_vals = []
    for rv in dynamic_rvs[1:]:
        dynamic_vals.append(sch.get(rv))
    dynamic_prod = dynamic_vals[0] * dynamic_vals[1] * dynamic_vals[2]
    assert dynamic_prod == 1

    verify_trace_roundtrip(sch, mod=workload)


if __name__ == "__main__":
tvm.testing.main()
35 changes: 35 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,41 @@ def test_split_with_inferred_factor():
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_split_with_dynamic_inferred_factor():
    """Split loops with symbolic extents while leaving one factor as ``None``.

    Loops ``i`` (extent ``N``) and ``k`` (extent ``M``) are split with one factor
    left as ``None``; loop ``j`` (extent 128) is split with explicit factors.
    The scheduled module is compared against ``expected``.
    """

    @T.prim_func
    def before(a: T.handle, b: T.handle) -> None:
        N = T.int32()
        M = T.int32()
        A = T.match_buffer(a, (N, 128, M))
        B = T.match_buffer(b, (N, 128, M))
        for i, j, k in T.grid(N, 128, M):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0

    @T.prim_func
    def expected(a: T.handle, b: T.handle) -> None:
        N, M = T.int32(), T.int32()
        A = T.match_buffer(a, (N, 128, M))
        B = T.match_buffer(b, (N, 128, M))
        for i_0, i_1, j_0, j_1, k_0, k_1 in T.grid((N + 15) // 16, 16, 4, 32, 16, (M + 15) // 16):
            with T.block("B"):
                vi = T.axis.spatial(N, i_0 * 16 + i_1)
                vj = T.axis.spatial(128, j_0 * 32 + j_1)
                vk = T.axis.spatial(M, k_0 * ((M + 15) // 16) + k_1)
                T.where(i_0 * 16 + i_1 < N and k_0 * ((M + 15) // 16) + k_1 < M)
                B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2.0)

    sch = tir.Schedule(before, debug_mask="all")
    loop_i, loop_j, loop_k = sch.get_loops(sch.get_block("B"))

    # Apply the splits in i, j, k order; None marks the factor to be inferred.
    split_plan = (
        (loop_i, [None, 16]),
        (loop_j, [4, 32]),
        (loop_k, [16, None]),
    )
    for loop_rv, split_factors in split_plan:
        sch.split(loop_rv, factors=split_factors)

    assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=before)


def test_split_with_predicate():
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
Expand Down