Skip to content

[NFC][IR2Vec] Removing Dimension from Embedder::Create #142486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2025

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Jun 2, 2025

This PR removes the necessity to know the dimension of the embeddings while invoking Embedder::Create. Having the Dimension parameter introduces complexities in downstream consumers.

(Tracking issue - #141817)

Copy link
Contributor Author

This stack of pull requests is managed by Graphite. Learn more about stacking.

@svkeerthy svkeerthy changed the title Remove Dimension [IR2Vec] Removing Dimension from Embedder::Create Jun 2, 2025
@svkeerthy svkeerthy marked this pull request as ready for review June 2, 2025 21:06
@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

This PR removes the necessity to know the dimension of the embeddings while invoking Embedder::Create. Having the Dimension parameter introduces complexities in downstream consumers.

(Tracking issue - #141817)


Full diff: https://github.com/llvm/llvm-project/pull/142486.diff

4 Files Affected:

  • (modified) llvm/docs/MLGO.rst (+1-2)
  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+5-8)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+8-12)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+5-6)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 377c2aec44475..4f8fb3f59ca19 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
         return;
       }
       const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
-      unsigned Dimension = VocabRes.getDimension();
 
    Note that ``IR2VecVocabAnalysis`` pass is immutable.
 
@@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
       // Assuming F is an llvm::Function&
       // For example, using IR2VecKind::Symbolic:
       Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
-          ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
+          ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
 
       if (auto Err = EmbOrErr.takeError()) {
         // Handle error in embedder creation
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 288753b3b3b8f..9fd1b0ae8e248 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -84,7 +84,7 @@ class Embedder {
   mutable BBEmbeddingsMap BBVecMap;
   mutable InstEmbeddingsMap InstVecMap;
 
-  Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
+  Embedder(const Function &F, const Vocab &Vocabulary);
 
   /// Helper function to compute embeddings. It generates embeddings for all
   /// the instructions and basic blocks in the function F. Logic of computing
@@ -110,10 +110,8 @@ class Embedder {
   virtual ~Embedder() = default;
 
   /// Factory method to create an Embedder object.
-  static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
-                                                    const Function &F,
-                                                    const Vocab &Vocabulary,
-                                                    unsigned Dimension);
+  static Expected<std::unique_ptr<Embedder>>
+  create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
 
   /// Returns a map containing instructions and the corresponding embeddings for
   /// the function F if it has been computed. If not, it computes the embeddings
@@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder {
   void computeEmbeddings(const BasicBlock &BB) const override;
 
 public:
-  SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
-                   unsigned Dimension)
-      : Embedder(F, Vocabulary, Dimension) {
+  SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
+      : Embedder(F, Vocabulary) {
     FuncVector = Embedding(Dimension, 0);
   }
 };
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 67af44dcac424..490db5fdcdf99 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key;
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
 
-Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
-                   unsigned Dimension)
-    : F(F), Vocabulary(Vocabulary), Dimension(Dimension),
-      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
+    : F(F), Vocabulary(Vocabulary),
+      Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
+      TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
 
-Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
-                                                     const Function &F,
-                                                     const Vocab &Vocabulary,
-                                                     unsigned Dimension) {
+Expected<std::unique_ptr<Embedder>>
+Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
-    return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
+    return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
   }
   return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
 }
@@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
   assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
 
   auto Vocab = IR2VecVocabResult.getVocabulary();
-  auto Dim = IR2VecVocabResult.getDimension();
   for (Function &F : M) {
     Expected<std::unique_ptr<Embedder>> EmbOrErr =
-        Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim);
+        Embedder::create(IR2VecKind::Symbolic, F, Vocab);
     if (auto Err = EmbOrErr.takeError()) {
       handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
         OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 0158038b59b6c..9e47b2cd8bedd 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -28,8 +28,7 @@ namespace {
 
 class TestableEmbedder : public Embedder {
 public:
-  TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
-      : Embedder(F, V, Dim) {}
+  TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
   void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
   using Embedder::lookupVocab;
@@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
   EXPECT_TRUE(static_cast<bool>(Result));
 
   auto *Emb = Result->get();
@@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) {
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
   // static_cast an invalid int to IR2VecKind
-  auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
+  auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
   EXPECT_FALSE(static_cast<bool>(Result));
 
   std::string ErrMsg;
@@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  TestableEmbedder E(*F, V, 2);
+  TestableEmbedder E(*F, V);
   auto V_foo = E.lookupVocab("foo");
   EXPECT_EQ(V_foo.size(), 2u);
   EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
@@ -190,7 +189,7 @@ struct GetterTestEnv {
     Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
     Ret = ReturnInst::Create(Ctx, Add, BB);
 
-    auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+    auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
     EXPECT_TRUE(static_cast<bool>(Result));
     Emb = std::move(*Result);
   }

@svkeerthy svkeerthy changed the title [IR2Vec] Removing Dimension from Embedder::Create [NFC][IR2Vec] Removing Dimension from Embedder::Create Jun 2, 2025
Copy link
Member

mtrofin commented Jun 2, 2025

IIUC the main thing is that the dimension is trivially obtainable from the rest of the data (i.e. not so much about downstream-ness, as more about future maintainability - why duplicate information)

Copy link
Contributor Author

That's right. Having to pass dimension would complicate the consumers in a way that they are expected to know the dimension apriori (which in a way necessitates tracking the dimension).

Copy link
Contributor Author

svkeerthy commented Jun 2, 2025

Merge activity

  • Jun 2, 10:03 PM UTC: A user started a stack merge that includes this pull request via Graphite.
  • Jun 2, 10:05 PM UTC: @svkeerthy merged this pull request with Graphite.

@svkeerthy svkeerthy merged commit 741136a into main Jun 2, 2025
14 of 16 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/06-02-remove_dimension branch June 2, 2025 22:05
sallto pushed a commit to sallto/llvm-project that referenced this pull request Jun 3, 2025
This PR removes the necessity to know the dimension of the embeddings while invoking `Embedder::Create`. Having the `Dimension` parameter introduces complexities in downstream consumers. 

(Tracking issue - llvm#141817)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants