llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

<details>
<summary>Changes</summary>

Refactored IR2Vec vocabulary handling to improve code organization and error
handling. This will help with upcoming PRs related to the IR2Vec tool.

(Tracking issue - #<!-- -->141817)


---
Full diff: https://github.com/llvm/llvm-project/pull/147585.diff


2 Files Affected:

- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+6-3) 
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+24-21) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f5a4e450cf160..176cdaf7b5378 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -163,15 +163,18 @@ class Vocabulary {
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
 
+  /// Helper function to get vocabulary key for a given Opcode
+  static StringRef getVocabKeyForOpcode(unsigned Opcode);
+
+  /// Helper function to get vocabulary key for a given TypeID
+  static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
+
   /// Helper function to get vocabulary key for a given OperandKind
   static StringRef getVocabKeyForOperandKind(OperandKind Kind);
 
   /// Helper function to classify an operand into OperandKind
   static OperandKind getOperandKind(const Value *Op);
 
-  /// Helper function to get vocabulary key for a given TypeID
-  static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
-
 public:
   Vocabulary() = default;
   Vocabulary(VocabVector &&Vocab);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index d3dc2e36fd23e..f97644b93a3d4 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
   return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
 }
 
+StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
+  if (Opcode == NUM) {                                                         \
+    return #OPCODE;                                                            \
+  }
+#include "llvm/IR/Instruction.def"
+#undef HANDLE_INST
+  return "UnknownOpcode";
+}
+
 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
   switch (TypeID) {
   case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
   default:
     return "UnknownTy";
   }
+  return "UnknownTy";
 }
 
 // Operand kinds supported by IR2Vec - string mappings
@@ -297,9 +309,9 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
     OPERAND_KINDS
 #undef OPERAND_KIND
   case Vocabulary::OperandKind::MaxOperandKind:
-    llvm_unreachable("Invalid OperandKind");
+    return "UnknownOperand";
   }
-  llvm_unreachable("Unknown OperandKind");
+  return "UnknownOperand";
 }
 
 #undef OPERAND_KINDS
@@ -332,14 +344,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
   assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
          "Position out of bounds in vocabulary");
   // Opcode
-  if (Pos < MaxOpcodes) {
-#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
-  if (Pos == NUM - 1) {                                                        \
-    return #OPCODE;                                                            \
-  }
-#include "llvm/IR/Instruction.def"
-#undef HANDLE_INST
-  }
+  if (Pos < MaxOpcodes)
+    return getVocabKeyForOpcode(Pos + 1);
   // Type
   if (Pos < MaxOpcodes + MaxTypeIDs)
     return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
@@ -447,21 +453,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Opcodes
   std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
                                                  Embedding(Dim, 0));
-#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
-  {                                                                            \
-    auto It = OpcVocab.find(#OPCODE);                                          \
-    if (It != OpcVocab.end())                                                  \
-      NumericOpcodeEmbeddings[NUM - 1] = It->second;                           \
-    else                                                                       \
-      handleMissingEntity(#OPCODE);                                            \
+  for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
+    StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
+    auto It = OpcVocab.find(VocabKey.str());
+    if (It != OpcVocab.end())
+      NumericOpcodeEmbeddings[Opcode] = It->second;
+    else
+      handleMissingEntity(VocabKey.str());
   }
-#include "llvm/IR/Instruction.def"
-#undef HANDLE_INST
   Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
                NumericOpcodeEmbeddings.end());
 
-  // Handle Types using direct iteration through TypeID enum
-  // We iterate through all possible TypeID values and map them to embeddings
+  // Handle Types
   std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
                                                Embedding(Dim, 0));
   for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/147585
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to