https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/149212

>From 68ae9f559439dd1b486713536c925f900afdfbad Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeer...@google.com>
Date: Wed, 16 Jul 2025 21:49:05 +0000
Subject: [PATCH] [IR2Vec] Expose Vocabulary::getNumericID() and expectedSize()

---
 llvm/include/llvm/Analysis/IR2Vec.h    | 10 ++++
 llvm/lib/Analysis/IR2Vec.cpp           | 20 +++++++-
 llvm/unittests/Analysis/IR2VecTest.cpp | 63 ++++++++++++++++++++++++++
 3 files changed, 91 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3d7edf08c8807..d87457cac7642 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,6 +170,11 @@ class Vocabulary {
   unsigned getDimension() const;
   size_t size() const;
 
+  /// Total entries in a valid vocabulary (opcodes + type IDs + operand kinds).
+  static size_t expectedSize() {
+    return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
+  }
+
   /// Helper function to get vocabulary key for a given Opcode
   static StringRef getVocabKeyForOpcode(unsigned Opcode);
 
@@ -182,6 +187,11 @@ class Vocabulary {
   /// Helper function to classify an operand into OperandKind
   static OperandKind getOperandKind(const Value *Op);
 
+  /// Return the vocabulary index of an Opcode, TypeID, or operand Value.
+  static unsigned getNumericID(unsigned Opcode);
+  static unsigned getNumericID(Type::TypeID TypeID);
+  static unsigned getNumericID(const Value *Op);
+
   /// Accessors to get the embedding for a given entity.
   const ir2vec::Embedding &operator[](unsigned Opcode) const;
   const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 898bf5b202feb..95f30fd3f4275 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
     : Vocab(std::move(Vocab)), Valid(true) {}
 
 bool Vocabulary::isValid() const {
-  return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
+  return Vocab.size() == Vocabulary::expectedSize() && Valid;
 }
 
 size_t Vocabulary::size() const {
@@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
+unsigned Vocabulary::getNumericID(unsigned Opcode) {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+  return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
+  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+  return MaxOpcodes + static_cast<unsigned>(TypeID);
+}
+
+unsigned Vocabulary::getNumericID(const Value *Op) {
+  unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+  assert(Index < MaxOperandKinds && "Invalid OperandKind");
+  return MaxOpcodes + MaxTypeIDs + Index;
+}
+
 StringRef Vocabulary::getStringKey(unsigned Pos) {
-  assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
+  assert(Pos < Vocabulary::expectedSize() &&
          "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index cb6d633306a81..7c9a5464bfe1d 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
   }
 }
 
+TEST(IR2VecVocabularyTest, NumericIDMap) {
+  // Test getNumericID for opcodes
+  EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
+  EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
+  EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
+
+  // Test getNumericID for Type IDs
+  EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
+
+  // Test getNumericID for Value operands
+  LLVMContext Ctx;
+  Module M("TestM", Ctx);
+  FunctionType *FTy =
+      FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+
+  // Test Function operand
+  EXPECT_EQ(Vocabulary::getNumericID(F),
+            MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+
+  // Test Constant operand
+  Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
+  EXPECT_EQ(Vocabulary::getNumericID(C),
+            MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+
+  // Test Pointer operand
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+  AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
+  EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
+            MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+
+  // Test Variable operand (function argument)
+  Argument *Arg = F->getArg(0);
+  EXPECT_EQ(Vocabulary::getNumericID(Arg),
+            MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
+  // Test invalid opcode IDs
+  EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
+  EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+
+  // Test invalid type IDs
+  EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+               "Invalid type ID");
+  EXPECT_DEATH(
+      Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+      "Invalid type ID");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
 TEST(IR2VecVocabularyTest, StringKeyGeneration) {
   EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
   EXPECT_EQ(Vocabulary::getStringKey(12), "Add");

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to