Skip to content

Commit f62a534

Browse files
authored
Merge pull request #4 from junrushao1994/feature/2022-08-15/group2-2
Extract leaf blocks only
2 parents 592af63 + 435bd4f commit f62a534

File tree

2 files changed

+306
-80
lines changed

2 files changed

+306
-80
lines changed

src/meta_schedule/feature_extractor/per_block_feature.cc

Lines changed: 41 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
#include <unordered_set>
2727
#include <vector>
2828

29-
#include "../utils.h"
3029
#include "./utils.h"
3130

3231
namespace tvm {
@@ -618,7 +617,7 @@ void Feature::SubFeature::SetReuse(const LoopNest& loop_nest,
618617
extent = 1;
619618
}
620619
const IntVec& touch = buffer_touched_under_loop.at(loop_idx);
621-
reuse_dis_iter = std::accumulate(touch.begin(), touch.end(), 1);
620+
reuse_dis_iter = std::accumulate(touch.begin(), touch.end(), 0);
622621
reuse_dis_bytes = 0.0;
623622
int buffer_idx = -1;
624623
for (int64_t numel : touch) {
@@ -677,6 +676,7 @@ Feature::Feature(const BlockRealizeNode* realize, const LoopNest& loop_nest,
677676
for (SubFeature& feature : sub_features) {
678677
const BufferNode* buffer = feature.buffer_;
679678
int64_t numel = buffer_touched_under_loop.front().at(++buffer_idx);
679+
LOG(INFO) << "feature: " << feature.buffer_->name << ", numel = " << numel;
680680
feature.SetFeature(loop_nest, cache_line_bytes, numel * buffer->dtype.bytes());
681681
}
682682
// Step 5. Calculate `for_touched_bytes`
@@ -999,15 +999,15 @@ class PerBlockFeatureCollector : private StmtVisitor {
999999
std::vector<Feature> result;
10001000
result.reserve(collector.block_features_.size());
10011001
for (const BlockRealizeNode* realize : collector.ordered_blocks_) {
1002+
LOG(INFO) << "Block: " << realize->block->name_hint;
10021003
Feature& feature = collector.block_features_.at(realize);
1003-
if (feature.block_realize != nullptr) {
1004-
ICHECK(feature.group1);
1005-
ICHECK(feature.group2);
1006-
ICHECK(feature.group3);
1007-
ICHECK(feature.group4);
1008-
ICHECK(feature.group5);
1009-
result.push_back(std::move(feature));
1010-
}
1004+
ICHECK(feature.block_realize == realize);
1005+
ICHECK(feature.group1);
1006+
ICHECK(feature.group2);
1007+
ICHECK(feature.group3);
1008+
ICHECK(feature.group4);
1009+
ICHECK(feature.group5);
1010+
result.push_back(std::move(feature));
10111011
}
10121012
return result;
10131013
}
@@ -1020,33 +1020,34 @@ class PerBlockFeatureCollector : private StmtVisitor {
10201020
arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {}
10211021

10221022
void VisitStmt_(const BlockRealizeNode* realize) final {
1023-
if (!scopes_.empty()) {
1024-
ordered_blocks_.push_back(realize);
1025-
}
1026-
Feature& feature = block_features_[realize];
1027-
feature.block_realize = realize;
1028-
feature.group1 = std::make_unique<group1::Feature>(realize, loop_nest_, is_gpu_);
1023+
ordered_blocks_.push_back(realize);
1024+
int previous_num_blocks_visited = ++this->num_blocks_visited_;
10291025
scopes_.push_back(realize);
10301026
dfs_path_.push_back(realize);
10311027
StmtVisitor::VisitStmt_(realize);
10321028
dfs_path_.pop_back();
10331029
scopes_.pop_back();
1034-
if (!scopes_.empty()) {
1035-
AddArithOpsToScope(&feature.group1->arith_ops);
1036-
}
1037-
IntVec for_touched_bytes;
1038-
feature.group2 = std::make_unique<group2::Feature>(realize, loop_nest_, cache_line_bytes_,
1039-
scopes_.empty() ? nullptr : scopes_.back(),
1040-
&for_touched_bytes, &analyzer_);
1041-
feature.group3 =
1042-
std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
1043-
for_touched_bytes, feature.group1->arith_ops);
1044-
feature.group4 = std::make_unique<group4::Feature>(loop_nest_, realize, &analyzer_,
1045-
&alloc_buffer_outer_loops_);
1046-
feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
1047-
// Erase the feature of the root block
1048-
if (scopes_.empty()) {
1049-
block_features_.erase(realize);
1030+
// only extract features for leaf blocks
1031+
if (previous_num_blocks_visited == this->num_blocks_visited_) {
1032+
IntVec for_touched_bytes;
1033+
Feature& feature = block_features_[realize];
1034+
feature.block_realize = realize;
1035+
if (feature.group1 == nullptr) {
1036+
feature.group1 = std::make_unique<group1::Feature>(realize, loop_nest_, is_gpu_);
1037+
}
1038+
feature.group2 = std::make_unique<group2::Feature>(realize, loop_nest_, cache_line_bytes_,
1039+
scopes_.empty() ? nullptr : scopes_.back(),
1040+
&for_touched_bytes, &analyzer_);
1041+
feature.group3 =
1042+
std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
1043+
for_touched_bytes, feature.group1->arith_ops);
1044+
feature.group4 = std::make_unique<group4::Feature>(loop_nest_, realize, &analyzer_,
1045+
&alloc_buffer_outer_loops_);
1046+
feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
1047+
block_features_.emplace(realize, Feature{});
1048+
} else {
1049+
ordered_blocks_.erase(
1050+
std::find(std::begin(ordered_blocks_), std::end(ordered_blocks_), realize));
10501051
}
10511052
}
10521053

@@ -1069,14 +1070,13 @@ class PerBlockFeatureCollector : private StmtVisitor {
10691070
ICHECK(!scopes_.empty());
10701071
group1::Feature::ArithOps arith_ops;
10711072
arith_ops.AddExpr(store->value, loop_nest_.prod);
1072-
AddArithOpsToScope(&arith_ops);
1073-
}
1074-
1075-
void AddArithOpsToScope(group1::Feature::ArithOps* arith_ops) {
10761073
const BlockRealizeNode* scope = scopes_.back();
1077-
// Add the arith_ops to the parent
1078-
group1::Feature::ArithOps& parent_arith_ops = block_features_[scope].group1->arith_ops;
1079-
#define TVM_FEATURE_MATH_OP_ADD(Name) parent_arith_ops.Name += arith_ops->Name;
1074+
std::unique_ptr<group1::Feature>& feature = block_features_[scope].group1;
1075+
if (feature == nullptr) {
1076+
block_features_[scope].block_realize = scope;
1077+
feature = std::make_unique<group1::Feature>(scope, loop_nest_, is_gpu_);
1078+
}
1079+
#define TVM_FEATURE_MATH_OP_ADD(Name) feature->arith_ops.Name += arith_ops.Name;
10801080
TVM_FEATURE_MATH_OP_ADD(float_mad);
10811081
TVM_FEATURE_MATH_OP_ADD(float_add_sub);
10821082
TVM_FEATURE_MATH_OP_ADD(float_mul);
@@ -1097,6 +1097,7 @@ class PerBlockFeatureCollector : private StmtVisitor {
10971097
}
10981098

10991099
bool is_gpu_;
1100+
int num_blocks_visited_ = 0;
11001101
int64_t cache_line_bytes_;
11011102
int64_t arith_intensity_curve_num_samples_;
11021103
arith::Analyzer analyzer_;
@@ -1134,6 +1135,7 @@ class PerBlockFeatureNode : public FeatureExtractorNode {
11341135
using namespace tvm::tir::per_block_feature;
11351136
static transform::Sequential passes = tir::transform::PassListForFeatureExtraction();
11361137
mod = passes(std::move(mod));
1138+
LOG(INFO) << "mod =\n" << tir::AsTVMScript(mod);
11371139
std::vector<Feature> features = PerBlockFeatureCollector::Collect(
11381140
is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod);
11391141
int n_features = features.size();

0 commit comments

Comments
 (0)