https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/147585

>From 5eaecce25822a1e4d1aa7e1bb200f6eff7f29234 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeer...@google.com>
Date: Mon, 7 Jul 2025 21:30:29 +0000
Subject: [PATCH] [NFC][IR2Vec] Minor refactoring of opcode access in
 vocabulary

---
 llvm/include/llvm/Analysis/IR2Vec.h |  9 ++++---
 llvm/lib/Analysis/IR2Vec.cpp        | 41 ++++++++++++++++-------------
 2 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index e35793617f7da..2498a211e80e5 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -162,15 +162,18 @@ class Vocabulary {
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
 
+  /// Helper function to get vocabulary key for a given Opcode
+  static StringRef getVocabKeyForOpcode(unsigned Opcode);
+
+  /// Helper function to get vocabulary key for a given TypeID
+  static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
+
   /// Helper function to get vocabulary key for a given OperandKind
   static StringRef getVocabKeyForOperandKind(OperandKind Kind);
 
   /// Helper function to classify an operand into OperandKind
   static OperandKind getOperandKind(const Value *Op);
 
-  /// Helper function to get vocabulary key for a given TypeID
-  static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
-
 public:
   Vocabulary() = default;
   Vocabulary(VocabVector &&Vocab);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index b1255c76367b2..c6e1fa32c9ffd 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
   return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
 }
 
+StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
+  if (Opcode == NUM) {                                                         \
+    return #OPCODE;                                                            \
+  }
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+  return "UnknownOpcode";
+}
+
 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
   switch (TypeID) {
   case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
   default:
     return "UnknownTy";
   }
+  return "UnknownTy";
 }
 
 StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -316,14 +328,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
   assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
          "Position out of bounds in vocabulary");
   // Opcode
-  if (Pos < MaxOpcodes) {
-#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
-  if (Pos == NUM - 1) {                                                        \
-    return #OPCODE;                                                            \
-  }
-#include "llvm/IR/Instruction.def"
-#undef HANDLE_INST
-  }
+  if (Pos < MaxOpcodes)
+    return getVocabKeyForOpcode(Pos + 1);
   // Type
   if (Pos < MaxOpcodes + MaxTypeIDs)
     return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
@@ -431,21 +437,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Opcodes
   std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
                                                  Embedding(Dim, 0));
-#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
-  {                                                                            \
-    auto It = OpcVocab.find(#OPCODE);                                          \
-    if (It != OpcVocab.end())                                                  \
-      NumericOpcodeEmbeddings[NUM - 1] = It->second;                           \
-    else                                                                       \
-      handleMissingEntity(#OPCODE);                                            \
+  for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
+    StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
+    auto It = OpcVocab.find(VocabKey.str());
+    if (It != OpcVocab.end())
+      NumericOpcodeEmbeddings[Opcode] = It->second;
+    else
+      handleMissingEntity(VocabKey.str());
   }
-#include "llvm/IR/Instruction.def"
-#undef HANDLE_INST
   Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
                NumericOpcodeEmbeddings.end());
 
-  // Handle Types using direct iteration through TypeID enum
-  // We iterate through all possible TypeID values and map them to embeddings
+  // Handle Types
   std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
                                                Embedding(Dim, 0));
   for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to