Skip to content

Commit ffa7cde

Browse files
committed
Fix bug in group4 features
1 parent d9b576a commit ffa7cde

File tree

3 files changed

+61
-31
lines changed

3 files changed

+61
-31
lines changed

exp/test_tensorization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,5 @@ def get_output(data, lib):
151151

152152
if __name__ == "__main__":
153153
# test_cuda_tensor_core("bert_base", (8, 128))
154-
# test_cuda_matmul()
155-
get_search_space()
154+
test_cuda_matmul()
155+
# get_search_space()

src/meta_schedule/feature_extractor/per_block_feature.cc

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -701,12 +701,14 @@ Feature::Feature(const BlockRealizeNode* realize, const LoopNest& loop_nest,
701701
}
702702
return a.buffer_->name < b.buffer_->name;
703703
});
704+
/*
704705
{
705706
int i = -1;
706707
for (SubFeature& feature : sub_features) {
707708
LOG(INFO) << "Buffer #" << (++i) << ": " << feature.buffer_->name;
708709
}
709710
}
711+
*/
710712
}
711713

712714
} // namespace group2
@@ -794,40 +796,41 @@ struct Feature {
794796
int64_t alloc_size_local = 0;
795797
int64_t alloc_size_shared = 0;
796798
int64_t alloc_size_global = 0;
797-
// alloc_outer_prod * alloc_inner_prod
799+
// alloc_outer_prod * written_inner_prod / buffer size
798800
int64_t alloc_prod_local = 0;
799801
int64_t alloc_prod_shared = 0;
800802
int64_t alloc_prod_global = 0;
801-
// The product of lengths of loops outside the scope of the alloc
802-
int64_t alloc_outer_prod_local = 1;
803-
int64_t alloc_outer_prod_shared = 1;
804-
int64_t alloc_outer_prod_global = 1;
803+
// The product of lengths of loops outside the scope of the alloc * buffer size
804+
int64_t alloc_outer_prod_local = 0;
805+
int64_t alloc_outer_prod_shared = 0;
806+
int64_t alloc_outer_prod_global = 0;
807+
// The product of lengths of loops inside the scope of alloc before the buffer is written to *
808+
// buffer size
809+
int64_t written_inner_prod_local = 0;
810+
int64_t written_inner_prod_shared = 0;
811+
int64_t written_inner_prod_global = 0;
805812

806813
static constexpr int64_t kCount = 12;
807814

808-
void Export(std::vector<double>* v, int64_t outer_prod) const {
815+
void Export(std::vector<double>* v) const {
809816
double vs[] = {
810-
slog(alloc_size_local),
811-
slog(alloc_size_shared),
812-
slog(alloc_size_global),
813-
slog(alloc_prod_local),
814-
slog(alloc_prod_shared),
815-
slog(alloc_prod_global),
816-
slog(alloc_outer_prod_local),
817-
slog(alloc_outer_prod_shared),
818-
slog(alloc_outer_prod_global),
819-
slog(static_cast<double>(outer_prod) / alloc_outer_prod_local),
820-
slog(static_cast<double>(outer_prod) / alloc_outer_prod_shared),
821-
slog(static_cast<double>(outer_prod) / alloc_outer_prod_global),
817+
slog(alloc_size_local), slog(alloc_size_shared),
818+
slog(alloc_size_global), slog(alloc_prod_local),
819+
slog(alloc_prod_shared), slog(alloc_prod_global),
820+
slog(alloc_outer_prod_local), slog(alloc_outer_prod_shared),
821+
slog(alloc_outer_prod_global), slog(written_inner_prod_local),
822+
slog(written_inner_prod_shared), slog(written_inner_prod_global),
822823
};
823824
v->insert(v->end(), std::begin(vs), std::end(vs));
824825
}
825826

826827
Feature() = default;
827828

828829
explicit Feature(const LoopNest& loop_nest, const BlockRealizeNode* realize,
829-
arith::Analyzer* analyzer) {
830+
arith::Analyzer* analyzer,
831+
std::unordered_map<const Buffer*, int64_t>* alloc_buffer_outer_loops_) {
830832
for (const Buffer& buffer : realize->block->alloc_buffers) {
833+
(*alloc_buffer_outer_loops_)[&buffer] = loop_nest.prod;
831834
std::vector<int64_t> shape = utils::GetBufferShape(buffer, analyzer);
832835
int64_t numel = 1;
833836
for (int64_t x : shape) {
@@ -837,18 +840,43 @@ struct Feature {
837840
switch (storage_scope.rank) {
838841
case runtime::StorageRank::kLocal:
839842
alloc_size_local += numel * buffer->dtype.bytes();
840-
alloc_prod_local += numel * loop_nest.prod;
841-
alloc_outer_prod_local *= loop_nest.prod;
843+
alloc_outer_prod_local += numel * loop_nest.prod;
842844
break;
843845
case runtime::StorageRank::kShared:
844846
alloc_size_shared += numel * buffer->dtype.bytes();
845-
alloc_prod_shared += numel * loop_nest.prod;
846-
alloc_outer_prod_shared *= loop_nest.prod;
847+
alloc_outer_prod_shared += numel * loop_nest.prod;
847848
break;
848849
case runtime::StorageRank::kGlobal:
849850
alloc_size_global += numel * buffer->dtype.bytes();
851+
alloc_outer_prod_global += numel * loop_nest.prod;
852+
break;
853+
default:
854+
break;
855+
}
856+
}
857+
858+
const Array<BufferRegion> write_buffers = realize->block->writes;
859+
for (const BufferRegion& write_buffer : write_buffers) {
860+
const Buffer& buffer = write_buffer->buffer;
861+
std::vector<int64_t> shape = utils::GetBufferShape(buffer, analyzer);
862+
int64_t numel = 1;
863+
for (int64_t x : shape) {
864+
numel *= x;
865+
}
866+
int64_t outer_loops = (*alloc_buffer_outer_loops_)[&buffer];
867+
runtime::StorageScope storage_scope = runtime::StorageScope::Create(buffer.scope());
868+
switch (storage_scope.rank) {
869+
case runtime::StorageRank::kLocal:
870+
alloc_prod_local += numel * loop_nest.prod;
871+
written_inner_prod_local += numel * (static_cast<double>(loop_nest.prod) / outer_loops);
872+
break;
873+
case runtime::StorageRank::kShared:
874+
alloc_prod_shared += numel * loop_nest.prod;
875+
written_inner_prod_shared += numel * (static_cast<double>(loop_nest.prod) / outer_loops);
876+
break;
877+
case runtime::StorageRank::kGlobal:
850878
alloc_prod_global += numel * loop_nest.prod;
851-
alloc_outer_prod_global *= loop_nest.prod;
879+
written_inner_prod_global += numel * (static_cast<double>(loop_nest.prod) / outer_loops);
852880
break;
853881
default:
854882
break;
@@ -1013,7 +1041,8 @@ class PerBlockFeatureCollector : private StmtVisitor {
10131041
feature.group3 =
10141042
std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
10151043
for_touched_bytes, feature.group1->arith_ops);
1016-
feature.group4 = std::make_unique<group4::Feature>(loop_nest_, realize, &analyzer_);
1044+
feature.group4 = std::make_unique<group4::Feature>(loop_nest_, realize, &analyzer_,
1045+
&alloc_buffer_outer_loops_);
10171046
feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
10181047
// Erase the feature of the root block
10191048
if (scopes_.empty()) {
@@ -1074,6 +1103,7 @@ class PerBlockFeatureCollector : private StmtVisitor {
10741103
std::vector<const BlockRealizeNode*> scopes_;
10751104
std::vector<const StmtNode*> dfs_path_;
10761105
LoopNest loop_nest_ = {};
1106+
std::unordered_map<const Buffer*, int64_t> alloc_buffer_outer_loops_ = {};
10771107
std::unordered_map<const BlockRealizeNode*, Feature> block_features_ = {};
10781108
std::vector<const BlockRealizeNode*> ordered_blocks_;
10791109
};
@@ -1115,7 +1145,7 @@ class PerBlockFeatureNode : public FeatureExtractorNode {
11151145
feature.group1->Export(&result);
11161146
feature.group2->Export(&result, this->buffers_per_block);
11171147
feature.group3->Export(&result);
1118-
feature.group4->Export(&result, feature.group5->outer_prod);
1148+
feature.group4->Export(&result);
11191149
feature.group5->Export(&result);
11201150
}
11211151
}
@@ -1141,8 +1171,8 @@ class PerBlockFeatureNode : public FeatureExtractorNode {
11411171
}
11421172
results[task_id] = tir::utils::AsNDArray(features);
11431173
};
1144-
f(0, 0);
1145-
// support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
1174+
// f(0, 0);
1175+
support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
11461176
return results;
11471177
}
11481178

src/meta_schedule/feature_extractor/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ inline std::tuple<int64_t, IntVec> RelaxAndUnion(const std::vector<NDIntSet>& re
272272
for (int j = 0; j < n_regions; ++j) {
273273
int_sets.push_back(regions[j][i]);
274274
}
275-
LOG(INFO) << "dim " << i << ": " << int_sets;
275+
// LOG(INFO) << "dim " << i << ": " << int_sets;
276276
arith::IntSet union_set = arith::Union(int_sets);
277277
// Update the area
278278
int64_t min = analyzer->const_int_bound(union_set.min())->min_value;

0 commit comments

Comments
 (0)