https://github.com/svkeerthy updated 
https://github.com/llvm/llvm-project/pull/147844

>From bf757c03868bf5e85966440408e41f5343727384 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeer...@google.com>
Date: Wed, 9 Jul 2025 22:44:03 +0000
Subject: [PATCH] IR2Vec Tool Enhancements

---
 llvm/test/tools/llvm-ir2vec/embeddings.ll |  73 +++++++++
 llvm/test/tools/llvm-ir2vec/triplets.ll   |   2 +-
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp    | 185 ++++++++++++++++++++--
 3 files changed, 248 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/tools/llvm-ir2vec/embeddings.ll

diff --git a/llvm/test/tools/llvm-ir2vec/embeddings.ll b/llvm/test/tools/llvm-ir2vec/embeddings.ll
new file mode 100644
index 0000000000000..d5eed749193ac
--- /dev/null
+++ b/llvm/test/tools/llvm-ir2vec/embeddings.ll
@@ -0,0 +1,73 @@
+; RUN: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT
+; RUN: llvm-ir2vec --mode=embeddings --level=func --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL
+; RUN: llvm-ir2vec --mode=embeddings --level=func --function=abc --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ABC
+; RUN: not llvm-ir2vec --mode=embeddings --level=func --function=def --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-DEF
+; RUN: llvm-ir2vec --mode=embeddings --level=bb --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL
+; RUN: llvm-ir2vec --mode=embeddings --level=bb --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL-ABC-REPEAT
+; RUN: llvm-ir2vec --mode=embeddings --level=inst --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL-ABC-REPEAT
+
+define dso_local noundef float @abc(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+define dso_local noundef float @abc_repeat(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+; CHECK-DEFAULT: Function: abc
+; CHECK-DEFAULT-NEXT: [ 878.00  889.00  900.00 ]
+; CHECK-DEFAULT-NEXT: Function: abc_repeat
+; CHECK-DEFAULT-NEXT: [ 878.00  889.00  900.00 ]
+
+; CHECK-FUNC-LEVEL: Function: abc 
+; CHECK-FUNC-LEVEL-NEXT: [ 878.00  889.00  900.00 ]
+; CHECK-FUNC-LEVEL-NEXT: Function: abc_repeat 
+; CHECK-FUNC-LEVEL-NEXT: [ 878.00  889.00  900.00 ]
+
+; CHECK-FUNC-LEVEL-ABC: Function: abc
+; CHECK-FUNC-LEVEL-ABC-NEXT:  [ 878.00  889.00  900.00 ]
+
+; CHECK-FUNC-DEF: Error: Function 'def' not found
+
+; CHECK-BB-LEVEL: Function: abc
+; CHECK-BB-LEVEL-NEXT: entry: [ 878.00  889.00  900.00 ]
+; CHECK-BB-LEVEL-NEXT: Function: abc_repeat
+; CHECK-BB-LEVEL-NEXT: entry: [ 878.00  889.00  900.00 ]
+
+; CHECK-BB-LEVEL-ABC-REPEAT: Function: abc_repeat
+; CHECK-BB-LEVEL-ABC-REPEAT-NEXT: entry: [ 878.00  889.00  900.00 ]
+
+; CHECK-INST-LEVEL-ABC-REPEAT: Function: abc_repeat
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %a.addr = alloca i32, align 4 [ 91.00  92.00  93.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %b.addr = alloca float, align 4 [ 91.00  92.00  93.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store i32 %a, ptr %a.addr, align 4 [ 97.00  98.00  99.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store float %b, ptr %b.addr, align 4 [ 97.00  98.00  99.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %0 = load i32, ptr %a.addr, align 4 [ 94.00  95.00  96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %1 = load i32, ptr %a.addr, align 4 [ 94.00  95.00  96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %mul = mul nsw i32 %0, %1 [ 49.00  50.00  51.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %conv = sitofp i32 %mul to float [ 130.00  131.00  132.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %2 = load float, ptr %b.addr, align 4 [ 94.00  95.00  96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %add = fadd float %conv, %2 [ 40.00  41.00  42.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: ret float %add [ 1.00  2.00  3.00 ]
diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll
index fa5aaa895406f..d1ef5b388e258 100644
--- a/llvm/test/tools/llvm-ir2vec/triplets.ll
+++ b/llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -1,4 +1,4 @@
-; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS
+; RUN: llvm-ir2vec --mode=triplets %s | FileCheck %s -check-prefix=TRIPLETS
 
 define i32 @simple_add(i32 %a, i32 %b) {
 entry:
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 35e1c995fa4cc..ab2b734da233e 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -9,12 +9,18 @@
 /// \file
 /// This file implements the IR2Vec embedding generation tool.
 ///
-/// Currently supports triplet generation for vocabulary training.
-/// Future updates will support embedding generation using trained vocabulary.
+/// This tool provides two main functionalities:
 ///
-/// Usage: llvm-ir2vec input.bc -o triplets.txt
+/// 1. Triplet Generation Mode (--mode=triplets):
+///    Generates triplets (opcode, type, operands) for vocabulary training.
+///    Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
 ///
-/// TODO: Add embedding generation mode with vocabulary support
+/// 2. Embedding Generation Mode (--mode=embeddings):
+///    Generates IR2Vec embeddings using a trained vocabulary.
+///    Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
+///           --level=func input.bc -o embeddings.txt
+///    Levels: --level={inst,bb,func} (instructions, basic blocks, functions)
+///    (See IR2Vec.cpp for more embedding generation options)
 ///
 //===----------------------------------------------------------------------===//
 
@@ -24,6 +30,8 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Support/CommandLine.h"
@@ -34,7 +42,7 @@
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
-using namespace ir2vec;
+using namespace llvm::ir2vec;
 
 #define DEBUG_TYPE "ir2vec"
 
@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
                                            cl::init("-"),
                                            cl::cat(IR2VecToolCategory));
 
+enum ToolMode {
+  TripletMode,  // Generate triplets for vocabulary training
+  EmbeddingMode // Generate embeddings using trained vocabulary
+};
+
+static cl::opt<ToolMode>
+    Mode("mode", cl::desc("Tool operation mode:"),
+         cl::values(clEnumValN(TripletMode, "triplets",
+                               "Generate triplets for vocabulary training"),
+                    clEnumValN(EmbeddingMode, "embeddings",
+                               "Generate embeddings using trained vocabulary")),
+         cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
+
+static cl::opt<std::string>
+    FunctionName("function", cl::desc("Process specific function only"),
+                 cl::value_desc("name"), cl::Optional, cl::init(""),
+                 cl::cat(IR2VecToolCategory));
+
+enum EmbeddingLevel {
+  InstructionLevel, // Generate instruction-level embeddings
+  BasicBlockLevel,  // Generate basic block-level embeddings
+  FunctionLevel     // Generate function-level embeddings
+};
+
+static cl::opt<EmbeddingLevel>
+    Level("level", cl::desc("Embedding generation level (for embedding mode):"),
+          cl::values(clEnumValN(InstructionLevel, "inst",
+                                "Generate instruction-level embeddings"),
+                     clEnumValN(BasicBlockLevel, "bb",
+                                "Generate basic block-level embeddings"),
+                     clEnumValN(FunctionLevel, "func",
+                                "Generate function-level embeddings")),
+          cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
+
 namespace {
 
-/// Helper class for collecting IR information and generating triplets
+/// Helper class for collecting IR information and generating embeddings
 class IR2VecTool {
 private:
   Module &M;
+  ModuleAnalysisManager MAM;
+  const Vocabulary *Vocab = nullptr;
 
 public:
   explicit IR2VecTool(Module &M) : M(M) {}
 
+  /// Run the IR2Vec vocabulary analysis; returns true if the result is valid
+  bool initializeVocabulary() {
+    // Register the analyses with MAM, then query the vocabulary result and
+    // keep a pointer to it in Vocab. The vocabulary file path is specified
+    // via the --ir2vec-vocab-path global option
+    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+    return Vocab->isValid();
+  }
+
   /// Generate triplets for the entire module
   void generateTriplets(raw_ostream &OS) const {
     for (const Function &F : M)
@@ -81,6 +136,70 @@ class IR2VecTool {
     OS << LocalOutput;
   }
 
+  /// Generate embeddings for every function in the module
+  void generateEmbeddings(raw_ostream &OS) const {
+    // Vocab stays null until initializeVocabulary() runs; guard the deref.
+    if (!Vocab || !Vocab->isValid()) {
+      OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
+      return;
+    }
+    for (const Function &F : M)
+      generateEmbeddings(F, OS);
+  }
+
+  /// Print \p F's embeddings to \p OS at the granularity chosen by --level
+  void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+    if (F.isDeclaration()) {
+      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+      return;
+    }
+
+    // Create a symbolic embedder for F; the vocabulary must already be valid
+    assert(Vocab->isValid() && "Vocabulary is not valid");
+    auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab);
+    if (!Emb) {
+      OS << "Error: Failed to create embedder for function " << F.getName()
+         << "\n";
+      return;
+    }
+
+    OS << "Function: " << F.getName() << "\n";
+
+    // Emit vectors at the level selected by --level (defaults to func)
+    switch (Level) {
+    case FunctionLevel: {
+      Emb->getFunctionVector().print(OS);
+      break;
+    }
+    case BasicBlockLevel: {
+      const auto &BBVecMap = Emb->getBBVecMap();
+      for (const BasicBlock &BB : F) {
+        auto It = BBVecMap.find(&BB);
+        if (It != BBVecMap.end()) {
+          OS << BB.getName() << ":";
+          It->second.print(OS);
+        }
+      }
+      break;
+    }
+    case InstructionLevel: {
+      const auto &InstMap = Emb->getInstVecMap();
+      for (const BasicBlock &BB : F) {
+        for (const Instruction &I : BB) {
+          auto It = InstMap.find(&I);
+          if (It != InstMap.end()) {
+            I.print(OS);
+            It->second.print(OS);
+          }
+        }
+      }
+      break;
+    }
+    }
+
+    // Nothing further is written after the vectors (no trailing separator)
+  }
+
 private:
   /// Process a single basic block for triplet generation
   void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,8 +224,42 @@ class IR2VecTool {
 
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
-  Tool.generateTriplets(OS);
 
+  if (Mode == EmbeddingMode) {
+    // Initialize vocabulary for embedding generation
+    // Note: Requires --ir2vec-vocab-path option to be set
+    if (!Tool.initializeVocabulary())
+      return createStringError(
+          errc::invalid_argument,
+          "Failed to initialize IR2Vec vocabulary. "
+          "Make sure to specify --ir2vec-vocab-path for embedding mode.");
+
+    if (!FunctionName.empty()) {
+      // Process single function
+      if (const Function *F = M.getFunction(FunctionName))
+        Tool.generateEmbeddings(*F, OS);
+      else
+        return createStringError(errc::invalid_argument,
+                                 "Function '%s' not found",
+                                 FunctionName.c_str());
+    } else {
+      // Process all functions
+      Tool.generateEmbeddings(OS);
+    }
+  } else {
+    // Triplet generation mode - no vocabulary needed
+    if (!FunctionName.empty())
+      // Process single function
+      if (const Function *F = M.getFunction(FunctionName))
+        Tool.generateTriplets(*F, OS);
+      else
+        return createStringError(errc::invalid_argument,
+                                 "Function '%s' not found",
+                                 FunctionName.c_str());
+    else
+      // Process all functions
+      Tool.generateTriplets(OS);
+  }
   return Error::success();
 }
 
@@ -117,11 +270,21 @@ int main(int argc, char **argv) {
   cl::HideUnrelatedOptions(IR2VecToolCategory);
   cl::ParseCommandLineOptions(
       argc, argv,
-      "IR2Vec - Triplet Generation Tool\n"
-      "Generates triplets for vocabulary training from LLVM IR.\n"
-      "Future updates will support embedding generation.\n\n"
+      "IR2Vec - Embedding Generation Tool\n"
+      "Generates embeddings for a given LLVM IR and "
+      "supports triplet generation for vocabulary "
+      "training and embedding generation.\n\n"
       "Usage:\n"
-      "  llvm-ir2vec input.bc -o triplets.txt\n");
+      "  Triplet mode:   llvm-ir2vec --mode=triplets input.bc\n"
+      "  Embedding mode: llvm-ir2vec --mode=embeddings "
+      "--ir2vec-vocab-path=vocab.json --level=func input.bc\n"
+      "  Levels: --level=inst (instructions), --level=bb (basic blocks), "
+      "--level=func (functions)\n");
+
+  // Validate command line options
+  if (Mode == TripletMode && Level != FunctionLevel) {
+    errs() << "Warning: --level option is ignored in triplet mode\n";
+  }
 
   // Parse the input LLVM IR file
   SMDiagnostic Err;

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to