2626#include < unordered_set>
2727#include < vector>
2828
29- #include " ../utils.h"
3029#include " ./utils.h"
3130
3231namespace tvm {
@@ -618,7 +617,7 @@ void Feature::SubFeature::SetReuse(const LoopNest& loop_nest,
618617 extent = 1 ;
619618 }
620619 const IntVec& touch = buffer_touched_under_loop.at (loop_idx);
621- reuse_dis_iter = std::accumulate (touch.begin (), touch.end (), 1 );
620+ reuse_dis_iter = std::accumulate (touch.begin (), touch.end (), 0 );
622621 reuse_dis_bytes = 0.0 ;
623622 int buffer_idx = -1 ;
624623 for (int64_t numel : touch) {
@@ -677,6 +676,7 @@ Feature::Feature(const BlockRealizeNode* realize, const LoopNest& loop_nest,
677676 for (SubFeature& feature : sub_features) {
678677 const BufferNode* buffer = feature.buffer_ ;
679678 int64_t numel = buffer_touched_under_loop.front ().at (++buffer_idx);
679+ LOG (INFO) << " feature: " << feature.buffer_ ->name << " , numel = " << numel;
680680 feature.SetFeature (loop_nest, cache_line_bytes, numel * buffer->dtype .bytes ());
681681 }
682682 // Step 5. Calculate `for_touched_bytes`
@@ -999,15 +999,15 @@ class PerBlockFeatureCollector : private StmtVisitor {
999999 std::vector<Feature> result;
10001000 result.reserve (collector.block_features_ .size ());
10011001 for (const BlockRealizeNode* realize : collector.ordered_blocks_ ) {
1002+ LOG (INFO) << " Block: " << realize->block ->name_hint ;
10021003 Feature& feature = collector.block_features_ .at (realize);
1003- if (feature.block_realize != nullptr ) {
1004- ICHECK (feature.group1 );
1005- ICHECK (feature.group2 );
1006- ICHECK (feature.group3 );
1007- ICHECK (feature.group4 );
1008- ICHECK (feature.group5 );
1009- result.push_back (std::move (feature));
1010- }
1004+ ICHECK (feature.block_realize == realize);
1005+ ICHECK (feature.group1 );
1006+ ICHECK (feature.group2 );
1007+ ICHECK (feature.group3 );
1008+ ICHECK (feature.group4 );
1009+ ICHECK (feature.group5 );
1010+ result.push_back (std::move (feature));
10111011 }
10121012 return result;
10131013 }
@@ -1020,33 +1020,34 @@ class PerBlockFeatureCollector : private StmtVisitor {
10201020 arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {}
10211021
10221022 void VisitStmt_ (const BlockRealizeNode* realize) final {
1023- if (!scopes_.empty ()) {
1024- ordered_blocks_.push_back (realize);
1025- }
1026- Feature& feature = block_features_[realize];
1027- feature.block_realize = realize;
1028- feature.group1 = std::make_unique<group1::Feature>(realize, loop_nest_, is_gpu_);
1023+ ordered_blocks_.push_back (realize);
1024+ int previous_num_blocks_visited = ++this ->num_blocks_visited_ ;
10291025 scopes_.push_back (realize);
10301026 dfs_path_.push_back (realize);
10311027 StmtVisitor::VisitStmt_ (realize);
10321028 dfs_path_.pop_back ();
10331029 scopes_.pop_back ();
1034- if (!scopes_.empty ()) {
1035- AddArithOpsToScope (&feature.group1 ->arith_ops );
1036- }
1037- IntVec for_touched_bytes;
1038- feature.group2 = std::make_unique<group2::Feature>(realize, loop_nest_, cache_line_bytes_,
1039- scopes_.empty () ? nullptr : scopes_.back (),
1040- &for_touched_bytes, &analyzer_);
1041- feature.group3 =
1042- std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
1043- for_touched_bytes, feature.group1 ->arith_ops );
1044- feature.group4 = std::make_unique<group4::Feature>(loop_nest_, realize, &analyzer_,
1045- &alloc_buffer_outer_loops_);
1046- feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
1047- // Erase the feature of the root block
1048- if (scopes_.empty ()) {
1049- block_features_.erase (realize);
1030+ // only extract features for leaf blocks
1031+ if (previous_num_blocks_visited == this ->num_blocks_visited_ ) {
1032+ IntVec for_touched_bytes;
1033+ Feature& feature = block_features_[realize];
1034+ feature.block_realize = realize;
1035+ if (feature.group1 == nullptr ) {
1036+ feature.group1 = std::make_unique<group1::Feature>(realize, loop_nest_, is_gpu_);
1037+ }
1038+ feature.group2 = std::make_unique<group2::Feature>(realize, loop_nest_, cache_line_bytes_,
1039+ scopes_.empty () ? nullptr : scopes_.back (),
1040+ &for_touched_bytes, &analyzer_);
1041+ feature.group3 =
1042+ std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
1043+ for_touched_bytes, feature.group1 ->arith_ops );
1044+ feature.group4 = std::make_unique<group4::Feature>(loop_nest_, realize, &analyzer_,
1045+ &alloc_buffer_outer_loops_);
1046+ feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
1047+ block_features_.emplace (realize, Feature{});
1048+ } else {
1049+ ordered_blocks_.erase (
1050+ std::find (std::begin (ordered_blocks_), std::end (ordered_blocks_), realize));
10501051 }
10511052 }
10521053
@@ -1069,14 +1070,13 @@ class PerBlockFeatureCollector : private StmtVisitor {
10691070 ICHECK (!scopes_.empty ());
10701071 group1::Feature::ArithOps arith_ops;
10711072 arith_ops.AddExpr (store->value , loop_nest_.prod );
1072- AddArithOpsToScope (&arith_ops);
1073- }
1074-
1075- void AddArithOpsToScope (group1::Feature::ArithOps* arith_ops) {
10761073 const BlockRealizeNode* scope = scopes_.back ();
1077- // Add the arith_ops to the parent
1078- group1::Feature::ArithOps& parent_arith_ops = block_features_[scope].group1 ->arith_ops ;
1079- #define TVM_FEATURE_MATH_OP_ADD (Name ) parent_arith_ops.Name += arith_ops->Name;
1074+ std::unique_ptr<group1::Feature>& feature = block_features_[scope].group1 ;
1075+ if (feature == nullptr ) {
1076+ block_features_[scope].block_realize = scope;
1077+ feature = std::make_unique<group1::Feature>(scope, loop_nest_, is_gpu_);
1078+ }
1079+ #define TVM_FEATURE_MATH_OP_ADD (Name ) feature->arith_ops.Name += arith_ops.Name;
10801080 TVM_FEATURE_MATH_OP_ADD (float_mad);
10811081 TVM_FEATURE_MATH_OP_ADD (float_add_sub);
10821082 TVM_FEATURE_MATH_OP_ADD (float_mul);
@@ -1097,6 +1097,7 @@ class PerBlockFeatureCollector : private StmtVisitor {
10971097 }
10981098
10991099 bool is_gpu_;
1100+ int num_blocks_visited_ = 0 ;
11001101 int64_t cache_line_bytes_;
11011102 int64_t arith_intensity_curve_num_samples_;
11021103 arith::Analyzer analyzer_;
@@ -1134,6 +1135,7 @@ class PerBlockFeatureNode : public FeatureExtractorNode {
11341135 using namespace tvm ::tir::per_block_feature;
11351136 static transform::Sequential passes = tir::transform::PassListForFeatureExtraction ();
11361137 mod = passes (std::move (mod));
1138+ LOG (INFO) << " mod =\n " << tir::AsTVMScript (mod);
11371139 std::vector<Feature> features = PerBlockFeatureCollector::Collect (
11381140 is_gpu, this ->cache_line_bytes , this ->arith_intensity_curve_num_samples , mod);
11391141 int n_features = features.size ();
0 commit comments