https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/147585
>From 5eaecce25822a1e4d1aa7e1bb200f6eff7f29234 Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Mon, 7 Jul 2025 21:30:29 +0000 Subject: [PATCH] [NFC][IR2Vec] Minor refactoring of opcode access in vocabulary --- llvm/include/llvm/Analysis/IR2Vec.h | 9 ++++--- llvm/lib/Analysis/IR2Vec.cpp | 41 ++++++++++++++++------------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index e35793617f7da..2498a211e80e5 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -162,15 +162,18 @@ class Vocabulary { static constexpr unsigned MaxOperandKinds = static_cast<unsigned>(OperandKind::MaxOperandKind); + /// Helper function to get vocabulary key for a given Opcode + static StringRef getVocabKeyForOpcode(unsigned Opcode); + + /// Helper function to get vocabulary key for a given TypeID + static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); + /// Helper function to get vocabulary key for a given OperandKind static StringRef getVocabKeyForOperandKind(OperandKind Kind); /// Helper function to classify an operand into OperandKind static OperandKind getOperandKind(const Value *Op); - /// Helper function to get vocabulary key for a given TypeID - static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); - public: Vocabulary() = default; Vocabulary(VocabVector &&Vocab); diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index b1255c76367b2..c6e1fa32c9ffd 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const { return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)]; } +StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { + assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); +#define HANDLE_INST(NUM, OPCODE, CLASS) \ + if (Opcode == NUM) { \ + return 
#OPCODE; \ + } +#include "llvm/IR/Instruction.def" +#undef HANDLE_INST + return "UnknownOpcode"; +} + StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { switch (TypeID) { case Type::VoidTyID: @@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { default: return "UnknownTy"; } + return "UnknownTy"; } StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { @@ -316,14 +328,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) { assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds && "Position out of bounds in vocabulary"); // Opcode - if (Pos < MaxOpcodes) { -#define HANDLE_INST(NUM, OPCODE, CLASS) \ - if (Pos == NUM - 1) { \ - return #OPCODE; \ - } -#include "llvm/IR/Instruction.def" -#undef HANDLE_INST - } + if (Pos < MaxOpcodes) + return getVocabKeyForOpcode(Pos + 1); // Type if (Pos < MaxOpcodes + MaxTypeIDs) return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes)); @@ -431,21 +437,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, Embedding(Dim, 0)); -#define HANDLE_INST(NUM, OPCODE, CLASS) \ - { \ - auto It = OpcVocab.find(#OPCODE); \ - if (It != OpcVocab.end()) \ - NumericOpcodeEmbeddings[NUM - 1] = It->second; \ - else \ - handleMissingEntity(#OPCODE); \ + for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { + StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); + auto It = OpcVocab.find(VocabKey.str()); + if (It != OpcVocab.end()) + NumericOpcodeEmbeddings[Opcode] = It->second; + else + handleMissingEntity(VocabKey.str()); } -#include "llvm/IR/Instruction.def" -#undef HANDLE_INST Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), NumericOpcodeEmbeddings.end()); - // Handle Types using direct iteration through TypeID enum - // We iterate through all possible TypeID values and map them to embeddings + // Handle Types std::vector<Embedding> 
NumericTypeEmbeddings(Vocabulary::MaxTypeIDs, Embedding(Dim, 0)); for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits