|
13 | 13 | #include "schedule.h"
|
14 | 14 | #include "tree_builder.h"
|
15 | 15 | #include "util.h"
|
| 16 | +#include "../arithmetic/int_set.h" |
16 | 17 |
|
17 | 18 | namespace tvm {
|
18 | 19 |
|
@@ -388,26 +389,26 @@ BlockTreeNode Schedule::compute_at(BlockTreeNode block, AxisTreeNode axis) {
|
388 | 389 | new_args[i] = block->args[i];
|
389 | 390 | }
|
390 | 391 |
|
391 |
| - if (const arith::IntervalSet* set = iter_domain[i].as<arith::IntervalSet>()) { |
| 392 | + if (const arith::IntervalSetNode* set = iter_domain[i].as<arith::IntervalSetNode>()) { |
392 | 393 | node = AxisTreeNodeNode::make(iter_var,
|
393 |
| - set->i.min, |
394 |
| - set->i.max - set->i.min + 1, |
| 394 | + set->min_value, |
| 395 | + set->max_value - set->min_value + 1, |
395 | 396 | kOpaque, // todo(lmzheng): fill correct type to replace kOpaque x 3
|
396 | 397 | Array<ScheduleTreeNode>{last});
|
397 | 398 | new_args[i] = iter_var;
|
398 |
| - } else if (const arith::StrideSet* set = iter_domain[i].as<arith::StrideSet>()) { |
| 399 | + } else if (const arith::StrideSetNode* set = iter_domain[i].as<arith::StrideSetNode>()) { |
399 | 400 | CHECK(set->extents.size() == 1);
|
400 |
| - CHECK(set->base.is_single_point()); |
| 401 | + CHECK(is_one(set->base_extent)); |
401 | 402 | if (is_one(set->extents[0])) {
|
402 | 403 | node = AxisTreeNode(nullptr);
|
403 |
| - new_args[i] = set->base.min; |
| 404 | + new_args[i] = set->base_min; |
404 | 405 | } else {
|
405 | 406 | node = AxisTreeNodeNode::make(iter_var,
|
406 | 407 | 0,
|
407 | 408 | set->extents[0],
|
408 | 409 | kOpaque,
|
409 | 410 | Array<ScheduleTreeNode>{last});
|
410 |
| - new_args[i] = iter_var * set->strides[0] + set->base.min; |
| 411 | + new_args[i] = iter_var * set->strides[0] + set->base_min; |
411 | 412 | }
|
412 | 413 | } else {
|
413 | 414 | LOG(FATAL) << "Cannot handle int set : " << iter_domain[i];
|
@@ -539,11 +540,11 @@ BlockTreeNode Schedule::blockize(AxisTreeNode axis) {
|
539 | 540 | for (size_t i = 0; i < iter.first.ndim(); ++i) {
|
540 | 541 | Array<IntSet> to_merge;
|
541 | 542 | for (const std::vector<IntSet>& y : iter.second) {
|
542 |
| - const arith::IntervalSet* set = y[i].as<arith::IntervalSet>(); |
| 543 | + const arith::IntervalSetNode* set = y[i].as<arith::IntervalSetNode>(); |
543 | 544 | CHECK(set != nullptr);
|
544 |
| - IntSet b = arith::IntervalSet::make( |
545 |
| - SubstituteAndEquationSimplify(set->i.min, var_map, &analyzer), |
546 |
| - SubstituteAndEquationSimplify(set->i.max, var_map, &analyzer)); |
| 545 | + IntSet b = arith::IntervalSet( |
| 546 | + SubstituteAndEquationSimplify(set->min_value, var_map, &analyzer), |
| 547 | + SubstituteAndEquationSimplify(set->max_value, var_map, &analyzer)); |
547 | 548 | to_merge.push_back(b);
|
548 | 549 | }
|
549 | 550 | IntSet merged = arith::Union(to_merge);
|
@@ -600,7 +601,7 @@ BlockTreeNode Schedule::tensorize(BlockTreeNode block, TensorIntrinsic intrin) {
|
600 | 601 | block->inputs, block->outputs,
|
601 | 602 | Stmt(NodePtr<Node>(nullptr)),
|
602 | 603 | Array<ScheduleTreeNode>{Downcast<ScheduleTreeNode>(ret)});
|
603 |
| - } else if (ret->derived_from<HalideIR::Internal::BaseStmtNode>()) { |
| 604 | + } else if (ret->derived_from<StmtNode>()) { |
604 | 605 | new_block = BlockTreeNodeNode::make(block->args, block->vars,
|
605 | 606 | block->inputs, block->outputs,
|
606 | 607 | Downcast<Stmt>(ret), Array<ScheduleTreeNode>{});
|
|
0 commit comments