Skip to content

Commit

Permalink
[MetaSchedule] Extract workload embedding (#11975)
Browse files Browse the repository at this point in the history
This PR enables extracting the embeddings of the workload in a tuning context, which further strengthens the feature extracting process. Workload embeddings are extracted based on names of each block in the IR module. If `extract_workload` is enabled, the extracted feature vectors will have length 164 + 8 = 172.
  • Loading branch information
Kathryn-cat authored Jul 1, 2022
1 parent ec39199 commit 395e91f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 4 deletions.
4 changes: 3 additions & 1 deletion include/tvm/meta_schedule/feature_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ class FeatureExtractor : public runtime::ObjectRef {
* \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity
* curve.
* \param cache_line_bytes The number of bytes in a cache line.
* \param extract_workload Whether to extract features in the workload in tuning context or not.
* \return The feature extractor created.
*/
TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5,
int arith_intensity_curve_num_samples = 10,
int cache_line_bytes = 64);
int cache_line_bytes = 64,
bool extract_workload = false);
/*!
* \brief Create a feature extractor with customized methods on the python-side.
* \param f_extract_from The packed function of `ExtractFrom`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class PerStoreFeature(FeatureExtractor):
The number of samples used in the arithmetic intensity curve.
cache_line_bytes : int
The number of bytes in a cache line.
extract_workload : bool
Whether to extract features in the workload in tuning context or not.
"""

buffers_per_store: int
Expand All @@ -43,6 +45,8 @@ class PerStoreFeature(FeatureExtractor):
"""The number of samples used in the arithmetic intensity curve."""
cache_line_bytes: int
"""The number of bytes in a cache line."""
extract_workload: bool
"""Whether to extract features in the workload in tuning context or not."""
feature_vector_length: int
"""Length of the feature vector."""

Expand All @@ -51,10 +55,12 @@ def __init__(
buffers_per_store: int = 5,
arith_intensity_curve_num_samples: int = 10,
cache_line_bytes: int = 64,
extract_workload: bool = False,
):
self.__init_handle_by_constructor__(
_ffi_api.FeatureExtractorPerStoreFeature, # type: ignore # pylint: disable=no-member
buffers_per_store,
arith_intensity_curve_num_samples,
cache_line_bytes,
extract_workload,
)
79 changes: 76 additions & 3 deletions src/meta_schedule/feature_extractor/per_store_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cmath>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -1169,6 +1170,64 @@ struct Feature {

} // namespace group5

namespace group6 {

/*! \brief The auxiliary feature extractor for workloads */
class WorkloadEmbeddingExtractor : private StmtVisitor {
public:
static std::vector<double> Extract(const IRModule& mod) {
WorkloadEmbeddingExtractor self;
for (const auto& kv : mod->functions) {
if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
self(func->body);
}
}
return self.embedding;
}

private:
void VisitStmt_(const BlockNode* block) final {
StmtVisitor::VisitStmt_(block);
std::string name = block->name_hint;
std::for_each(name.begin(), name.end(), [](char& c) { c = ::tolower(c); });
if (name.find("softmax") != std::string::npos) {
embedding[0] = 1.0;
} else if ((name.find("max") != std::string::npos) || (name.find("min") != std::string::npos)) {
embedding[1] = 1.0;
} else if (name.find("add") != std::string::npos) {
embedding[2] = 1.0;
} else if (name.find("batch_matmul") != std::string::npos) {
embedding[3] = 1.0;
} else if (name.find("matmul") != std::string::npos) {
embedding[4] = 1.0;
} else if (name.find("depthwiseconv2d") != std::string::npos) {
embedding[5] = 1.0;
} else if (name.find("conv2d_winograd") != std::string::npos) {
embedding[6] = 1.0;
} else if (name.find("conv2d") != std::string::npos) {
embedding[7] = 1.0;
}
}

std::vector<double> embedding = std::vector<double>(8, 0.0);
};

/*! \brief Group 6 feature */
struct Feature {
explicit Feature(const IRModule& mod) {
this->feature = WorkloadEmbeddingExtractor::Extract(mod);
}

void Export(std::vector<double>* v) const {
v->insert(v->end(), std::begin(feature), std::end(feature));
}

std::vector<double> feature; // The workload embedding
static constexpr int64_t kCount = 8;
};

} // namespace group6

/*! \brief The feature extracted */
struct Feature {
const BufferNode* buffer = nullptr;
Expand All @@ -1178,6 +1237,7 @@ struct Feature {
std::unique_ptr<group3::Feature> group3 = nullptr;
std::unique_ptr<group4::Feature> group4 = nullptr;
std::unique_ptr<group5::Feature> group5 = nullptr;
std::shared_ptr<group6::Feature> group6 = nullptr;

bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; }
};
Expand Down Expand Up @@ -1283,6 +1343,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
int buffers_per_store;
int arith_intensity_curve_num_samples;
int cache_line_bytes;
bool extract_workload;
int feature_vector_length;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -1308,7 +1369,6 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
feature.group3->Export(&result);
feature.group4->Export(&result, feature.group5->outer_prod);
feature.group5->Export(&result);
ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
}
}

Expand All @@ -1317,10 +1377,19 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
bool is_gpu = tune_context->target.value()->kind->name == "cuda";
std::vector<runtime::NDArray> results;
results.resize(candidates.size());
auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;
if (extract_workload) {
feature_group6 = std::make_unique<tir::group6::Feature>(tune_context->mod.value());
}
auto f = [this, is_gpu, &feature_group6, &candidates, &results](int, int task_id) -> void {
const auto& candidate = candidates[task_id];
std::vector<std::vector<double>> features;
ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, &features);
if (extract_workload) {
for (auto& feature : features) {
feature_group6->Export(&feature);
}
}
results[task_id] = tir::utils::AsNDArray(features);
};
support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
Expand All @@ -1333,16 +1402,20 @@ class PerStoreFeatureNode : public FeatureExtractorNode {

FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
int arith_intensity_curve_num_samples,
int cache_line_bytes) {
int cache_line_bytes, bool extract_workload) {
ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
n->buffers_per_store = buffers_per_store;
n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
n->cache_line_bytes = cache_line_bytes;
n->extract_workload = extract_workload;
n->feature_vector_length = tir::group1::Feature::kCount + //
tir::group2::Feature::SubFeature::kCount * buffers_per_store + //
arith_intensity_curve_num_samples + //
tir::group4::Feature::kCount + //
tir::group5::Feature::kCount;
if (extract_workload) {
n->feature_vector_length += tir::group6::Feature::kCount;
}
return FeatureExtractor(n);
}

Expand Down

0 comments on commit 395e91f

Please sign in to comment.