Skip to content

[NFC][IR2Vec] Removing Dimension from Embedder::Create #142486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions llvm/docs/MLGO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
return;
}
const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
unsigned Dimension = VocabRes.getDimension();

Note that ``IR2VecVocabAnalysis`` pass is immutable.

Expand All @@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);

if (auto Err = EmbOrErr.takeError()) {
// Handle error in embedder creation
Expand Down
13 changes: 5 additions & 8 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Embedder {
mutable BBEmbeddingsMap BBVecMap;
mutable InstEmbeddingsMap InstVecMap;

Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
Embedder(const Function &F, const Vocab &Vocabulary);

/// Helper function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F. Logic of computing
Expand All @@ -110,10 +110,8 @@ class Embedder {
virtual ~Embedder() = default;

/// Factory method to create an Embedder object.
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
const Function &F,
const Vocab &Vocabulary,
unsigned Dimension);
static Expected<std::unique_ptr<Embedder>>
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);

/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
Expand Down Expand Up @@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder {
void computeEmbeddings(const BasicBlock &BB) const override;

public:
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
unsigned Dimension)
: Embedder(F, Vocabulary, Dimension) {
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
: Embedder(F, Vocabulary) {
FuncVector = Embedding(Dimension, 0);
}
};
Expand Down
20 changes: 8 additions & 12 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key;
// Embedder and its subclasses
//===----------------------------------------------------------------------===//

Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
unsigned Dimension)
: F(F), Vocabulary(Vocabulary), Dimension(Dimension),
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
}
Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
: F(F), Vocabulary(Vocabulary),
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}

Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
const Function &F,
const Vocab &Vocabulary,
unsigned Dimension) {
Expected<std::unique_ptr<Embedder>>
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
}
Expand Down Expand Up @@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");

auto Vocab = IR2VecVocabResult.getVocabulary();
auto Dim = IR2VecVocabResult.getDimension();
for (Function &F : M) {
Expected<std::unique_ptr<Embedder>> EmbOrErr =
Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim);
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
if (auto Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
Expand Down
11 changes: 5 additions & 6 deletions llvm/unittests/Analysis/IR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ namespace {

class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
: Embedder(F, V, Dim) {}
TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
using Embedder::lookupVocab;
Expand All @@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);

auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));

auto *Emb = Result->get();
Expand All @@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) {
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);

// static_cast an invalid int to IR2VecKind
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));

std::string ErrMsg;
Expand Down Expand Up @@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);

TestableEmbedder E(*F, V, 2);
TestableEmbedder E(*F, V);
auto V_foo = E.lookupVocab("foo");
EXPECT_EQ(V_foo.size(), 2u);
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
Expand Down Expand Up @@ -190,7 +189,7 @@ struct GetterTestEnv {
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
Ret = ReturnInst::Create(Ctx, Add, BB);

auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
Emb = std::move(*Result);
}
Expand Down
Loading