svkeerthy (https://github.com/svkeerthy) updated pull request https://github.com/llvm/llvm-project/pull/149212
>From 68ae9f559439dd1b486713536c925f900afdfbad Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Wed, 16 Jul 2025 21:49:05 +0000 Subject: [PATCH] [IR2Vec] Expose Vocabulary::expectedSize() and getNumericID() helpers --- llvm/include/llvm/Analysis/IR2Vec.h | 9 ++++ llvm/lib/Analysis/IR2Vec.cpp | 20 +++++++- llvm/unittests/Analysis/IR2VecTest.cpp | 63 ++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 3d7edf08c8807..d87457cac7642 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -170,6 +170,10 @@ class Vocabulary { unsigned getDimension() const; size_t size() const; + static size_t expectedSize() { + return MaxOpcodes + MaxTypeIDs + MaxOperandKinds; + } + /// Helper function to get vocabulary key for a given Opcode static StringRef getVocabKeyForOpcode(unsigned Opcode); @@ -182,6 +186,11 @@ class Vocabulary { /// Helper function to classify an operand into OperandKind static OperandKind getOperandKind(const Value *Op); + /// Helpers to get the flat vocabulary position of an Opcode/TypeID/Value + static unsigned getNumericID(unsigned Opcode); + static unsigned getNumericID(Type::TypeID TypeID); + static unsigned getNumericID(const Value *Op); + /// Accessors to get the embedding for a given entity. 
const ir2vec::Embedding &operator[](unsigned Opcode) const; const ir2vec::Embedding &operator[](Type::TypeID TypeId) const; diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 898bf5b202feb..95f30fd3f4275 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)), Valid(true) {} bool Vocabulary::isValid() const { - return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid; + return Vocab.size() == Vocabulary::expectedSize() && Valid; } size_t Vocabulary::size() const { @@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; } +unsigned Vocabulary::getNumericID(unsigned Opcode) { + assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); + return Opcode - 1; // Opcodes start at 1; positions start at 0 +} + +unsigned Vocabulary::getNumericID(Type::TypeID TypeID) { + assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); + return MaxOpcodes + static_cast<unsigned>(TypeID); // Type IDs follow opcodes +} + +unsigned Vocabulary::getNumericID(const Value *Op) { + unsigned Index = static_cast<unsigned>(getOperandKind(Op)); + assert(Index < MaxOperandKinds && "Invalid OperandKind"); + return MaxOpcodes + MaxTypeIDs + Index; // Operand kinds come last +} + StringRef Vocabulary::getStringKey(unsigned Pos) { - assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds && + assert(Pos < Vocabulary::expectedSize() && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes) diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index cb6d633306a81..7c9a5464bfe1d 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { } } +TEST(IR2VecVocabularyTest, NumericIDMap) { + // Test getNumericID for opcodes + 
EXPECT_EQ(Vocabulary::getNumericID(1u), 0u); + 
EXPECT_EQ(Vocabulary::getNumericID(13u), 12u); + EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1); + + // Test getNumericID for Type IDs + EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID), + MaxOpcodes + static_cast<unsigned>(Type::VoidTyID)); + EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID), + MaxOpcodes + static_cast<unsigned>(Type::HalfTyID)); + EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID), + MaxOpcodes + static_cast<unsigned>(Type::FloatTyID)); + EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID), + MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID)); + EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID), + MaxOpcodes + static_cast<unsigned>(Type::PointerTyID)); + + // Test getNumericID for Value operands + LLVMContext Ctx; + Module M("TestM", Ctx); + FunctionType *FTy = + FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M); + + // Test Function operand + EXPECT_EQ(Vocabulary::getNumericID(F), + MaxOpcodes + MaxTypeIDs + 0u); // Function = 0 + + // Test Constant operand + Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42); + EXPECT_EQ(Vocabulary::getNumericID(C), + MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2 + + // Test Pointer operand + BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); + AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB); + EXPECT_EQ(Vocabulary::getNumericID(PtrVal), + MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1 + + // Test Variable operand (function argument) + Argument *Arg = F->getArg(0); + EXPECT_EQ(Vocabulary::getNumericID(Arg), + MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3 +} + +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG +TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) { + // Test invalid opcode IDs (0u: plain 0 is ambiguous with const Value *) + EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode"); + EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode"); + + // Test invalid 
type IDs + EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)), + "Invalid type ID"); + EXPECT_DEATH( + Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)), + "Invalid type ID"); +} +#endif // NDEBUG +#endif // GTEST_HAS_DEATH_TEST + TEST(IR2VecVocabularyTest, StringKeyGeneration) { EXPECT_EQ(Vocabulary::getStringKey(0), "Ret"); EXPECT_EQ(Vocabulary::getStringKey(12), "Add"); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits