-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[NFC][IR2Vec] Removing Dimension from Embedder::Create
#142486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Embedder::Create
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) Changes: This PR removes the necessity to know the dimension of the embeddings while invoking `Embedder::Create`. Having the `Dimension` parameter introduces complexities in downstream consumers. (Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/142486.diff 4 Files Affected:
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 377c2aec44475..4f8fb3f59ca19 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
return;
}
const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
- unsigned Dimension = VocabRes.getDimension();
Note that ``IR2VecVocabAnalysis`` pass is immutable.
@@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
- ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
+ ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
if (auto Err = EmbOrErr.takeError()) {
// Handle error in embedder creation
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 288753b3b3b8f..9fd1b0ae8e248 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -84,7 +84,7 @@ class Embedder {
mutable BBEmbeddingsMap BBVecMap;
mutable InstEmbeddingsMap InstVecMap;
- Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
+ Embedder(const Function &F, const Vocab &Vocabulary);
/// Helper function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F. Logic of computing
@@ -110,10 +110,8 @@ class Embedder {
virtual ~Embedder() = default;
/// Factory method to create an Embedder object.
- static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
- const Function &F,
- const Vocab &Vocabulary,
- unsigned Dimension);
+ static Expected<std::unique_ptr<Embedder>>
+ create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
@@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder {
void computeEmbeddings(const BasicBlock &BB) const override;
public:
- SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
- unsigned Dimension)
- : Embedder(F, Vocabulary, Dimension) {
+ SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
+ : Embedder(F, Vocabulary) {
FuncVector = Embedding(Dimension, 0);
}
};
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 67af44dcac424..490db5fdcdf99 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key;
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
-Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
- unsigned Dimension)
- : F(F), Vocabulary(Vocabulary), Dimension(Dimension),
- OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
+ : F(F), Vocabulary(Vocabulary),
+ Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
+ TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
-Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
- const Function &F,
- const Vocab &Vocabulary,
- unsigned Dimension) {
+Expected<std::unique_ptr<Embedder>>
+Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
- return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
+ return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
}
@@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
auto Vocab = IR2VecVocabResult.getVocabulary();
- auto Dim = IR2VecVocabResult.getDimension();
for (Function &F : M) {
Expected<std::unique_ptr<Embedder>> EmbOrErr =
- Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim);
+ Embedder::create(IR2VecKind::Symbolic, F, Vocab);
if (auto Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 0158038b59b6c..9e47b2cd8bedd 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -28,8 +28,7 @@ namespace {
class TestableEmbedder : public Embedder {
public:
- TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
- : Embedder(F, V, Dim) {}
+ TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
using Embedder::lookupVocab;
@@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
auto *Emb = Result->get();
@@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) {
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
// static_cast an invalid int to IR2VecKind
- auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
+ auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
std::string ErrMsg;
@@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- TestableEmbedder E(*F, V, 2);
+ TestableEmbedder E(*F, V);
auto V_foo = E.lookupVocab("foo");
EXPECT_EQ(V_foo.size(), 2u);
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
@@ -190,7 +189,7 @@ struct GetterTestEnv {
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
Ret = ReturnInst::Create(Ctx, Add, BB);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
Emb = std::move(*Result);
}
|
Embedder::Create
Embedder::Create
IIUC, the main point is that the dimension is trivially obtainable from the rest of the data (i.e., this is less about downstream-ness and more about future maintainability: why duplicate information?) |
That's right. Having to pass the dimension would complicate consumers, because they would be expected to know the dimension a priori (which in turn would require them to track it). |
Merge activity
|
This PR removes the necessity to know the dimension of the embeddings while invoking `Embedder::Create`. Having the `Dimension` parameter introduces complexities in downstream consumers. (Tracking issue - llvm#141817)
This PR removes the necessity to know the dimension of the embeddings while invoking `Embedder::Create`. Having the `Dimension` parameter introduces complexities in downstream consumers. (Tracking issue - #141817)