pcc created this revision. pcc added a reviewer: vlad.tsyrklevich. Herald added a subscriber: mgrang.
Similar to CFI on virtual and indirect calls, this implementation tries to use program type information to make the checks as precise as possible. It works as follows, where `C` is the name of the class being defined or the target of a call, and the function type is assumed to be `void()`.

For virtual calls:
- Attach type metadata of type `void (B::*)()` to the addresses of the function pointers in vtables (not to the functions themselves), for each `B` that is a recursive dynamic base class of `C`, including `C` itself. This type metadata is annotated as being for virtual calls (its type identifier carries a `.virtual` suffix) to distinguish it from the non-virtual case.
- At the call site, check that the computed address of the function pointer in the vtable has type `void (C::*)()`.

For non-virtual calls:
- Attach type metadata to each non-virtual member function whose address can be taken with a member function pointer. A function of type `void()` in class `C` is given each of the types `void (B::*)()` where `B` is a most-base class of `C`. A most-base class of `C` is defined as a recursive base class of `C`, including `C` itself, that does not have any bases.
- At the call site, check that the function pointer has one of the types `void (B::*)()` where `B` is a most-base class of `C`.

TODO:
- Implement a fallback for the case where the class type is incomplete at the call site.
- Implement non-trapping and cross-DSO support.
- Mark this as unsupported with the Microsoft ABI for now.
- Write tests.

https://reviews.llvm.org/D47567

Files:
  clang/include/clang/Basic/Sanitizers.def
  clang/lib/CodeGen/CGVTables.cpp
  clang/lib/CodeGen/CodeGenModule.cpp
  clang/lib/CodeGen/CodeGenModule.h
  clang/lib/CodeGen/ItaniumCXXABI.cpp
  clang/lib/Driver/SanitizerArgs.cpp
Index: clang/lib/Driver/SanitizerArgs.cpp =================================================================== --- clang/lib/Driver/SanitizerArgs.cpp +++ clang/lib/Driver/SanitizerArgs.cpp @@ -44,7 +44,8 @@ TrappingSupported = (Undefined & ~Vptr) | UnsignedIntegerOverflow | Nullability | LocalBounds | CFI, TrappingDefault = CFI, - CFIClasses = CFIVCall | CFINVCall | CFIDerivedCast | CFIUnrelatedCast, + CFIClasses = + CFIVCall | CFINVCall | CFIMFCall | CFIDerivedCast | CFIUnrelatedCast, CompatibleWithMinimalRuntime = TrappingSupported, }; Index: clang/lib/CodeGen/ItaniumCXXABI.cpp =================================================================== --- clang/lib/CodeGen/ItaniumCXXABI.cpp +++ clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -621,10 +621,26 @@ VTableOffset = Builder.CreateTrunc(VTableOffset, CGF.Int32Ty); VTableOffset = Builder.CreateZExt(VTableOffset, CGM.PtrDiffTy); } + // Compute the address of the virtual function pointer. VTable = Builder.CreateGEP(VTable, VTableOffset); + VTable = Builder.CreateBitCast(VTable, FTy->getPointerTo()->getPointerTo()); + + // Check the address of the function pointer if CFI on member function + // pointers is enabled. + if (CGF.SanOpts.has(SanitizerKind::CFIMFCall)) { + llvm::Metadata *MD = + CGM.CreateMetadataIdentifierForVirtualMemPtrType(QualType(MPT, 0)); + llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *CastedVTable = Builder.CreateBitCast(VTable, CGF.Int8PtrTy); + llvm::Value *TypeTest = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {CastedVTable, TypeId}); + + CGF.EmitTrapCheck(TypeTest); + FnVirtual = Builder.GetInsertBlock(); + } // Load the virtual function to call. 
- VTable = Builder.CreateBitCast(VTable, FTy->getPointerTo()->getPointerTo()); llvm::Value *VirtualFn = Builder.CreateAlignedLoad(VTable, CGF.getPointerAlign(), "memptr.virtualfn"); @@ -636,6 +652,27 @@ llvm::Value *NonVirtualFn = Builder.CreateIntToPtr(FnAsInt, FTy->getPointerTo(), "memptr.nonvirtualfn"); + // Check the function pointer if CFI on member function pointers is enabled. + if (CGF.SanOpts.has(SanitizerKind::CFIMFCall)) { + llvm::Value *Bit = Builder.getFalse(); + llvm::Value *CastedNonVirtualFn = + Builder.CreateBitCast(NonVirtualFn, CGF.Int8PtrTy); + for (const CXXRecordDecl *Base : + CGM.getMostBaseClasses(MPT->getClass()->getAsCXXRecordDecl())) { + llvm::Metadata *MD = + CGM.CreateMetadataIdentifierForType(getContext().getMemberPointerType( + MPT->getPointeeType(), + getContext().getRecordType(Base).getTypePtr())); + llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *TypeTest = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {CastedNonVirtualFn, TypeId}); + Bit = Builder.CreateOr(Bit, TypeTest); + } + CGF.EmitTrapCheck(Bit); + FnNonVirtual = Builder.GetInsertBlock(); + } + // We're done. CGF.EmitBlock(FnEnd); llvm::PHINode *CalleePtr = Builder.CreatePHI(FTy->getPointerTo(), 2); Index: clang/lib/CodeGen/CodeGenModule.h =================================================================== --- clang/lib/CodeGen/CodeGenModule.h +++ clang/lib/CodeGen/CodeGenModule.h @@ -503,6 +503,7 @@ /// MDNodes. typedef llvm::DenseMap<QualType, llvm::Metadata *> MetadataTypeMap; MetadataTypeMap MetadataIdMap; + MetadataTypeMap VirtualMetadataIdMap; MetadataTypeMap GeneralizedMetadataIdMap; public: @@ -1229,6 +1230,10 @@ /// internal identifiers). llvm::Metadata *CreateMetadataIdentifierForType(QualType T); + /// Create a metadata identifier that is intended to be used to check virtual + /// calls via a member function pointer. 
+ llvm::Metadata *CreateMetadataIdentifierForVirtualMemPtrType(QualType T); + /// Create a metadata identifier for the generalization of the given type. /// This may either be an MDString (for external identifiers) or a distinct /// unnamed MDNode (for internal identifiers). @@ -1244,6 +1249,9 @@ void AddVTableTypeMetadata(llvm::GlobalVariable *VTable, CharUnits Offset, const CXXRecordDecl *RD); + std::vector<const CXXRecordDecl *> + getMostBaseClasses(const CXXRecordDecl *RD); + /// Get the declaration of std::terminate for the platform. llvm::Constant *getTerminateFn(); @@ -1405,6 +1413,9 @@ void ConstructDefaultFnAttrList(StringRef Name, bool HasOptnone, bool AttrOnCallSite, llvm::AttrBuilder &FuncAttrs); + + llvm::Metadata *CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix); }; } // end namespace CodeGen Index: clang/lib/CodeGen/CodeGenModule.cpp =================================================================== --- clang/lib/CodeGen/CodeGenModule.cpp +++ clang/lib/CodeGen/CodeGenModule.cpp @@ -1378,14 +1378,45 @@ GV->setLinkage(llvm::GlobalValue::ExternalWeakLinkage); } +std::vector<const CXXRecordDecl *> +CodeGenModule::getMostBaseClasses(const CXXRecordDecl *RD) { + llvm::SetVector<const CXXRecordDecl *> MostBases; + + std::function<void (const CXXRecordDecl *)> CollectMostBases; + CollectMostBases = [&](const CXXRecordDecl *RD) { + if (RD->getNumBases() == 0) + MostBases.insert(RD); + for (const CXXBaseSpecifier &B : RD->bases()) + CollectMostBases(B.getType()->getAsCXXRecordDecl()); + }; + CollectMostBases(RD); + return MostBases.takeVector(); +} + void CodeGenModule::CreateFunctionTypeMetadata(const FunctionDecl *FD, llvm::Function *F) { - // Only if we are checking indirect calls. 
- if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall)) + auto *MD = dyn_cast<CXXMethodDecl>(FD); + if (MD && !MD->isStatic()) { + if (!LangOpts.Sanitize.has(SanitizerKind::CFIMFCall)) + return; + + // Only functions whose address can be taken with a member function pointer + // need type metadata. + if (MD->isVirtual() || isa<CXXConstructorDecl>(MD) || + isa<CXXDestructorDecl>(MD)) + return; + + for (const CXXRecordDecl *Base : getMostBaseClasses(MD->getParent())) { + llvm::Metadata *Id = + CreateMetadataIdentifierForType(Context.getMemberPointerType( + FD->getType(), Context.getRecordType(Base).getTypePtr())); + F->addTypeMetadata(0, Id); + } return; + } - // Non-static class methods are handled via vtable pointer checks elsewhere. - if (isa<CXXMethodDecl>(FD) && !cast<CXXMethodDecl>(FD)->isStatic()) + // Only if we are checking indirect calls. + if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall)) return; // Additionally, if building with cross-DSO support... @@ -1396,13 +1427,13 @@ return; } - llvm::Metadata *MD = CreateMetadataIdentifierForType(FD->getType()); - F->addTypeMetadata(0, MD); + llvm::Metadata *Id = CreateMetadataIdentifierForType(FD->getType()); + F->addTypeMetadata(0, Id); F->addTypeMetadata(0, CreateMetadataIdentifierGeneralized(FD->getType())); // Emit a hash-based bit set entry for cross-DSO calls. 
if (CodeGenOpts.SanitizeCfiCrossDso) - if (auto CrossDsoTypeId = CreateCrossDsoCfiTypeId(MD)) + if (auto CrossDsoTypeId = CreateCrossDsoCfiTypeId(Id)) F->addTypeMetadata(0, llvm::ConstantAsMetadata::get(CrossDsoTypeId)); } @@ -4928,15 +4959,18 @@ } } -llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { - llvm::Metadata *&InternalId = MetadataIdMap[T.getCanonicalType()]; +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix) { + llvm::Metadata *&InternalId = Map[T.getCanonicalType()]; if (InternalId) return InternalId; if (isExternallyVisible(T->getLinkage())) { std::string OutName; llvm::raw_string_ostream Out(OutName); getCXXABI().getMangleContext().mangleTypeName(T, Out); + Out << Suffix; InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); } else { @@ -4947,6 +4981,15 @@ return InternalId; } +llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { + return CreateMetadataIdentifierImpl(T, MetadataIdMap, ""); +} + +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierForVirtualMemPtrType(QualType T) { + return CreateMetadataIdentifierImpl(T, VirtualMetadataIdMap, ".virtual"); +} + // Generalize pointer types to a void pointer with the qualifiers of the // originally pointed-to type, e.g. 
'const char *' and 'char * const *' // generalize to 'const void *' while 'char *' and 'const char **' generalize to @@ -4980,25 +5023,8 @@ } llvm::Metadata *CodeGenModule::CreateMetadataIdentifierGeneralized(QualType T) { - T = GeneralizeFunctionType(getContext(), T); - - llvm::Metadata *&InternalId = GeneralizedMetadataIdMap[T.getCanonicalType()]; - if (InternalId) - return InternalId; - - if (isExternallyVisible(T->getLinkage())) { - std::string OutName; - llvm::raw_string_ostream Out(OutName); - getCXXABI().getMangleContext().mangleTypeName(T, Out); - Out << ".generalized"; - - InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); - } else { - InternalId = llvm::MDNode::getDistinct(getLLVMContext(), - llvm::ArrayRef<llvm::Metadata *>()); - } - - return InternalId; + return CreateMetadataIdentifierImpl(GeneralizeFunctionType(getContext(), T), + GeneralizedMetadataIdMap, ".generalized"); } /// Returns whether this module needs the "all-vtables" type identifier. Index: clang/lib/CodeGen/CGVTables.cpp =================================================================== --- clang/lib/CodeGen/CGVTables.cpp +++ clang/lib/CodeGen/CGVTables.cpp @@ -1012,41 +1012,56 @@ CharUnits PointerWidth = Context.toCharUnitsFromBits(Context.getTargetInfo().getPointerWidth(0)); - typedef std::pair<const CXXRecordDecl *, unsigned> TypeMetadata; - std::vector<TypeMetadata> TypeMetadatas; - // Create type metadata for each address point. + typedef std::pair<const CXXRecordDecl *, unsigned> AddressPoint; + std::vector<AddressPoint> AddressPoints; for (auto &&AP : VTLayout.getAddressPoints()) - TypeMetadatas.push_back(std::make_pair( + AddressPoints.push_back(std::make_pair( AP.first.getBase(), VTLayout.getVTableOffset(AP.second.VTableIndex) + AP.second.AddressPointIndex)); - // Sort the type metadata for determinism. 
- llvm::sort(TypeMetadatas.begin(), TypeMetadatas.end(), - [this](const TypeMetadata &M1, const TypeMetadata &M2) { - if (&M1 == &M2) + // Sort the address points for determinism. + llvm::sort(AddressPoints.begin(), AddressPoints.end(), + [this](const AddressPoint &AP1, const AddressPoint &AP2) { + if (&AP1 == &AP2) return false; std::string S1; llvm::raw_string_ostream O1(S1); getCXXABI().getMangleContext().mangleTypeName( - QualType(M1.first->getTypeForDecl(), 0), O1); + QualType(AP1.first->getTypeForDecl(), 0), O1); O1.flush(); std::string S2; llvm::raw_string_ostream O2(S2); getCXXABI().getMangleContext().mangleTypeName( - QualType(M2.first->getTypeForDecl(), 0), O2); + QualType(AP2.first->getTypeForDecl(), 0), O2); O2.flush(); if (S1 < S2) return true; if (S1 != S2) return false; - return M1.second < M2.second; + return AP1.second < AP2.second; }); - for (auto TypeMetadata : TypeMetadatas) - AddVTableTypeMetadata(VTable, PointerWidth * TypeMetadata.second, - TypeMetadata.first); + ArrayRef<VTableComponent> Comps = VTLayout.vtable_components(); + for (auto AP : AddressPoints) { + // Create type metadata for the address point. + AddVTableTypeMetadata(VTable, PointerWidth * AP.second, AP.first); + + // The class associated with each address point could also potentially be + // used for indirect calls via a member function pointer, so we need to + // annotate the address of each function pointer with the appropriate member + // function pointer type. 
+ for (unsigned I = 0; I != Comps.size(); ++I) { + if (Comps[I].getKind() != VTableComponent::CK_FunctionPointer) + continue; + llvm::Metadata *MD = CreateMetadataIdentifierForVirtualMemPtrType( + Context.getMemberPointerType( + Comps[I].getFunctionDecl()->getType(), + Context.getRecordType(AP.first).getTypePtr())); + VTable->addTypeMetadata((PointerWidth * I).getQuantity(), MD); + } + } } Index: clang/include/clang/Basic/Sanitizers.def =================================================================== --- clang/include/clang/Basic/Sanitizers.def +++ clang/include/clang/Basic/Sanitizers.def @@ -104,12 +104,13 @@ SANITIZER("cfi-cast-strict", CFICastStrict) SANITIZER("cfi-derived-cast", CFIDerivedCast) SANITIZER("cfi-icall", CFIICall) +SANITIZER("cfi-mfcall", CFIMFCall) SANITIZER("cfi-unrelated-cast", CFIUnrelatedCast) SANITIZER("cfi-nvcall", CFINVCall) SANITIZER("cfi-vcall", CFIVCall) SANITIZER_GROUP("cfi", CFI, - CFIDerivedCast | CFIICall | CFIUnrelatedCast | CFINVCall | - CFIVCall) + CFIDerivedCast | CFIICall | CFIMFCall | CFIUnrelatedCast | + CFINVCall | CFIVCall) // Safe Stack SANITIZER("safe-stack", SafeStack)
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits