================ @@ -461,6 +463,162 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses( return Changed; } +namespace { + +enum class PointerEncoding { + Rotate, + PACCopyable, + PACNonCopyable, +}; + +bool expandProtectedFieldPtr(Function &Intr) { + Module &M = *Intr.getParent(); + + SmallPtrSet<GlobalValue *, 2> DSsToDeactivate; + SmallPtrSet<Instruction *, 2> LoadsStores; + + Type *Int8Ty = Type::getInt8Ty(M.getContext()); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + PointerType *PtrTy = PointerType::get(M.getContext(), 0); + + Function *SignIntr = + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_sign, {}); + Function *AuthIntr = + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_auth, {}); + + auto *EmuFnTy = FunctionType::get(Int64Ty, {Int64Ty, Int64Ty}, false); + FunctionCallee EmuSignIntr = M.getOrInsertFunction("__emupac_pacda", EmuFnTy); + FunctionCallee EmuAuthIntr = M.getOrInsertFunction("__emupac_autda", EmuFnTy); + + auto CreateSign = [&](IRBuilder<> &B, Value *Val, Value *Disc, + OperandBundleDef DSBundle) { + Function *F = B.GetInsertBlock()->getParent(); + Attribute FSAttr = F->getFnAttribute("target-features"); + if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) + return B.CreateCall(SignIntr, {Val, B.getInt32(2), Disc}, DSBundle); + return B.CreateCall(EmuSignIntr, {Val, Disc}, DSBundle); + }; + + auto CreateAuth = [&](IRBuilder<> &B, Value *Val, Value *Disc, + OperandBundleDef DSBundle) { + Function *F = B.GetInsertBlock()->getParent(); + Attribute FSAttr = F->getFnAttribute("target-features"); + if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) + return B.CreateCall(AuthIntr, {Val, B.getInt32(2), Disc}, DSBundle); + return B.CreateCall(EmuAuthIntr, {Val, Disc}, DSBundle); + }; + + auto GetDeactivationSymbol = [&](CallInst *Call) -> GlobalValue * { + if (auto Bundle = + Call->getOperandBundle(LLVMContext::OB_deactivation_symbol)) + return cast<GlobalValue>(Bundle->Inputs[0]); 
+ return nullptr; + }; + + for (User *U : Intr.users()) { + auto *Call = cast<CallInst>(U); + auto *DS = GetDeactivationSymbol(Call); + + for (Use &U : Call->uses()) { + if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { + if (isa<PointerType>(LI->getType())) { + LoadsStores.insert(LI); + continue; + } + } + if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { + if (U.getOperandNo() == 1 && + isa<PointerType>(SI->getValueOperand()->getType())) { + LoadsStores.insert(SI); + continue; + } + } + // Comparisons against null cannot be used to recover the original + // pointer so we allow them. + if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) { + if (auto *Op = dyn_cast<Constant>(CI->getOperand(0))) + if (Op->isNullValue()) + continue; + if (auto *Op = dyn_cast<Constant>(CI->getOperand(1))) + if (Op->isNullValue()) + continue; + } + if (DS) + DSsToDeactivate.insert(DS); + } + } + + for (Instruction *I : LoadsStores) { + auto *PointerOperand = isa<StoreInst>(I) + ? cast<StoreInst>(I)->getPointerOperand() + : cast<LoadInst>(I)->getPointerOperand(); + auto *Call = cast<CallInst>(PointerOperand); + + auto *Disc = Call->getArgOperand(1); + bool UseHWEncoding = cast<ConstantInt>(Call->getArgOperand(2))->getZExtValue(); + + GlobalValue *DS = GetDeactivationSymbol(Call); + OperandBundleDef DSBundle("deactivation-symbol", DS); + + if (auto *LI = dyn_cast<LoadInst>(I)) { + IRBuilder<> B(LI->getNextNode()); + auto *LIInt = cast<Instruction>(B.CreatePtrToInt(LI, B.getInt64Ty())); + Value *Auth; + if (UseHWEncoding) { + Auth = CreateAuth(B, LIInt, Disc, DSBundle); + } else { + Auth = B.CreateAdd(LIInt, Disc); + Auth = B.CreateIntrinsic( + Auth->getType(), Intrinsic::fshr, + {Auth, Auth, ConstantInt::get(Auth->getType(), 16)}); + } + LI->replaceAllUsesWith(B.CreateIntToPtr(Auth, B.getPtrTy())); + LIInt->setOperand(0, LI); + } else { + auto *SI = cast<StoreInst>(I); + IRBuilder<> B(SI); + auto *SIValInt = + B.CreatePtrToInt(SI->getValueOperand(), B.getInt64Ty()); + Value *Sign; + if 
(UseHWEncoding) { + Sign = CreateSign(B, SIValInt, Disc, DSBundle); + } else { + Sign = B.CreateIntrinsic( + SIValInt->getType(), Intrinsic::fshl, + {SIValInt, SIValInt, ConstantInt::get(SIValInt->getType(), 16)}); + Sign = B.CreateSub(Sign, Disc); + } + SI->setOperand(0, B.CreateIntToPtr(Sign, B.getPtrTy())); + } + } + + for (User *U : llvm::make_early_inc_range(Intr.users())) { + auto *Call = cast<CallInst>(U); + auto *Pointer = Call->getArgOperand(0); + + Call->replaceAllUsesWith(Pointer); + Call->eraseFromParent(); + } + + if (!DSsToDeactivate.empty()) { + Constant *Nop = + ConstantExpr::getIntToPtr(ConstantInt::get(Int64Ty, 0xd503201f), PtrTy); ---------------- nikic wrote:
What is this magic constant (0xd503201f)? https://github.com/llvm/llvm-project/pull/151647 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits