From 395e91ff54543864a90240d18c8efd8c277c758b Mon Sep 17 00:00:00 2001
From: "Kathryn (Jinqi) Chen" <65606304+Kathryn-cat@users.noreply.github.com>
Date: Thu, 30 Jun 2022 19:36:13 -0700
Subject: [PATCH] [MetaSchedule] Extract workload embedding (#11975)

This PR enables extracting the embeddings of the workload in a tuning
context, which further strengthens the feature extraction process.
Workload embeddings are extracted based on the name of each block in the
IR module. If `extract_workload` is enabled, the extracted feature
vectors will have length 164 + 8 = 172.
---
 include/tvm/meta_schedule/feature_extractor.h |  4 +-
 .../feature_extractor/per_store_feature.py    |  6 ++
 .../feature_extractor/per_store_feature.cc    | 79 ++++++++++++++++++-
 3 files changed, 85 insertions(+), 4 deletions(-)

diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h
index 02e9f26b2a..4165e5efe0 100644
--- a/include/tvm/meta_schedule/feature_extractor.h
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -101,11 +101,13 @@ class FeatureExtractor : public runtime::ObjectRef {
    * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity
    * curve.
    * \param cache_line_bytes The number of bytes in a cache line.
+   * \param extract_workload Whether to extract features in the workload in tuning context or not.
    * \return The feature extractor created.
    */
   TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5,
                                                   int arith_intensity_curve_num_samples = 10,
-                                                  int cache_line_bytes = 64);
+                                                  int cache_line_bytes = 64,
+                                                  bool extract_workload = false);
   /*!
    * \brief Create a feature extractor with customized methods on the python-side.
    * \param f_extract_from The packed function of `ExtractFrom`.
diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
index 306934d5f9..078a4af0e3 100644
--- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
+++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
@@ -35,6 +35,8 @@ class PerStoreFeature(FeatureExtractor):
         The number of samples used in the arithmetic intensity curve.
     cache_line_bytes : int
         The number of bytes in a cache line.
+    extract_workload : bool
+        Whether to extract features in the workload in tuning context or not.
""" buffers_per_store: int @@ -43,6 +45,8 @@ class PerStoreFeature(FeatureExtractor): """The number of samples used in the arithmetic intensity curve.""" cache_line_bytes: int """The number of bytes in a cache line.""" + extract_workload: bool + """Whether to extract features in the workload in tuning context or not.""" feature_vector_length: int """Length of the feature vector.""" @@ -51,10 +55,12 @@ def __init__( buffers_per_store: int = 5, arith_intensity_curve_num_samples: int = 10, cache_line_bytes: int = 64, + extract_workload: bool = False, ): self.__init_handle_by_constructor__( _ffi_api.FeatureExtractorPerStoreFeature, # type: ignore # pylint: disable=no-member buffers_per_store, arith_intensity_curve_num_samples, cache_line_bytes, + extract_workload, ) diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 93f6767b11..c29e5d61f0 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -1169,6 +1170,64 @@ struct Feature { } // namespace group5 +namespace group6 { + +/*! \brief The auxiliary feature extractor for workloads */ +class WorkloadEmbeddingExtractor : private StmtVisitor { + public: + static std::vector Extract(const IRModule& mod) { + WorkloadEmbeddingExtractor self; + for (const auto& kv : mod->functions) { + if (const PrimFuncNode* func = kv.second.as()) { + self(func->body); + } + } + return self.embedding; + } + + private: + void VisitStmt_(const BlockNode* block) final { + StmtVisitor::VisitStmt_(block); + std::string name = block->name_hint; + std::for_each(name.begin(), name.end(), [](char& c) { c = ::tolower(c); }); + if (name.find("softmax") != std::string::npos) { + embedding[0] = 1.0; + } else if ((name.find("max") != std::string::npos) || (name.find("min") != std::string::npos)) { + embedding[1] = 1.0; + } else if (name.find("add") != std::string::npos) { + embedding[2] = 1.0; + } else if (name.find("batch_matmul") != std::string::npos) { + embedding[3] = 1.0; + } else if (name.find("matmul") != std::string::npos) { + embedding[4] = 1.0; + } else if (name.find("depthwiseconv2d") != std::string::npos) { + embedding[5] = 1.0; + } else if (name.find("conv2d_winograd") != std::string::npos) { + embedding[6] = 1.0; + } else if (name.find("conv2d") != std::string::npos) { + embedding[7] = 1.0; + } + } + + std::vector embedding = std::vector(8, 0.0); +}; + +/*! \brief Group 6 feature */ +struct Feature { + explicit Feature(const IRModule& mod) { + this->feature = WorkloadEmbeddingExtractor::Extract(mod); + } + + void Export(std::vector* v) const { + v->insert(v->end(), std::begin(feature), std::end(feature)); + } + + std::vector feature; // The workload embedding + static constexpr int64_t kCount = 8; +}; + +} // namespace group6 + /*! 
 struct Feature {
   const BufferNode* buffer = nullptr;
@@ -1178,6 +1237,7 @@ struct Feature {
   std::unique_ptr<group3::Feature> group3 = nullptr;
   std::unique_ptr<group4::Feature> group4 = nullptr;
   std::unique_ptr<group5::Feature> group5 = nullptr;
+  std::shared_ptr<group6::Feature> group6 = nullptr;
 
   bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; }
 };
@@ -1283,6 +1343,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
   int buffers_per_store;
   int arith_intensity_curve_num_samples;
   int cache_line_bytes;
+  bool extract_workload;
   int feature_vector_length;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -1308,7 +1369,6 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
       feature.group3->Export(&result);
       feature.group4->Export(&result, feature.group5->outer_prod);
       feature.group5->Export(&result);
-      ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
     }
   }
 
@@ -1317,10 +1377,19 @@
     bool is_gpu = tune_context->target.value()->kind->name == "cuda";
     std::vector<runtime::NDArray> results;
     results.resize(candidates.size());
-    auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
+    std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;
+    if (extract_workload) {
+      feature_group6 = std::make_unique<tir::group6::Feature>(tune_context->mod.value());
+    }
+    auto f = [this, is_gpu, &feature_group6, &candidates, &results](int, int task_id) -> void {
       const auto& candidate = candidates[task_id];
       std::vector<std::vector<double>> features;
       ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, &features);
+      if (extract_workload) {
+        for (auto& feature : features) {
+          feature_group6->Export(&feature);
+        }
+      }
       results[task_id] = tir::utils::AsNDArray(features);
     };
     support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
@@ -1333,16 +1402,20 @@
 
 FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
                                                     int arith_intensity_curve_num_samples,
-                                                    int cache_line_bytes) {
+                                                    int cache_line_bytes, bool extract_workload) {
   ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
   n->buffers_per_store = buffers_per_store;
   n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
   n->cache_line_bytes = cache_line_bytes;
+  n->extract_workload = extract_workload;
   n->feature_vector_length = tir::group1::Feature::kCount +  //
                              tir::group2::Feature::SubFeature::kCount * buffers_per_store +  //
                              arith_intensity_curve_num_samples +  //
                              tir::group4::Feature::kCount +  //
                              tir::group5::Feature::kCount;
+  if (extract_workload) {
+    n->feature_vector_length += tir::group6::Feature::kCount;
+  }
   return FeatureExtractor(n);
 }
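
For context, a minimal usage sketch of the new flag from the Python side. This is not part of the patch itself, and it assumes PerStoreFeature is re-exported from the tvm.meta_schedule.feature_extractor package touched above:

    # Usage sketch (assumption: PerStoreFeature is importable from the package below).
    from tvm.meta_schedule.feature_extractor import PerStoreFeature

    # Enable the 8-slot workload embedding on top of the per-store features.
    extractor = PerStoreFeature(extract_workload=True)
    # With extract_workload=True, each per-store feature vector is expected to
    # have 164 + 8 = 172 values; with the default (False), it stays at 164.

The embedding itself is a keyword-indicator vector over lower-cased block names (softmax, max/min, add, batch_matmul, matmul, depthwise conv2d, conv2d_winograd, conv2d), computed once per IRModule in the tuning context and appended to every per-store feature vector of each measure candidate.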