================ @@ -81,6 +136,70 @@ class IR2VecTool { OS << LocalOutput; } + /// Generate embeddings for the entire module + void generateEmbeddings(raw_ostream &OS) const { + if (!Vocab->isValid()) { + OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n"; + return; + } + + for (const Function &F : M) + generateEmbeddings(F, OS); + } + + /// Generate embeddings for a single function + void generateEmbeddings(const Function &F, raw_ostream &OS) const { + if (F.isDeclaration()) { + OS << "Function " << F.getName() << " is a declaration, skipping.\n"; + return; + } + + // Create embedder for this function + assert(Vocab->isValid() && "Vocabulary is not valid"); + auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab); + if (!Emb) { + OS << "Error: Failed to create embedder for function " << F.getName() + << "\n"; + return; + } + + OS << "Function: " << F.getName() << "\n"; + + // Generate embeddings based on the specified level + switch (Level) { + case FunctionLevel: { + Emb->getFunctionVector().print(OS); + break; + } + case BasicBlockLevel: { + const auto &BBVecMap = Emb->getBBVecMap(); + for (const BasicBlock &BB : F) { + auto It = BBVecMap.find(&BB); + if (It != BBVecMap.end()) { + OS << BB.getName() << ":"; + It->second.print(OS); + } + } + break; + } + case InstructionLevel: { + const auto &InstMap = Emb->getInstVecMap(); + for (const BasicBlock &BB : F) { + for (const Instruction &I : BB) { + auto It = InstMap.find(&I); + if (It != InstMap.end()) { + I.print(OS); + It->second.print(OS); + } + } + } + break; + } + } + + // OS << "\n"; ---------------- boomanaiden154 wrote:
Remove this commented-out `OS << "\n";` line? https://github.com/llvm/llvm-project/pull/147844 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits