https://github.com/bwendling updated 
https://github.com/llvm/llvm-project/pull/141846

>From bfdf16cee252dc6491f15c0715c7a39538fbf120 Mon Sep 17 00:00:00 2001
From: Bill Wendling <mo...@google.com>
Date: Wed, 23 Apr 2025 13:14:24 -0700
Subject: [PATCH 1/2] [Clang][attr] Add cfi_salt attribute

The new 'cfi_salt' attribute is used to differentiate between two
functions with the same type signature. It's applied to the function
declaration, the function definition, and any function pointer typedefs
to ensure the "salt" is passed through to code generation.

Simple example:

    #define __cfi_salt __attribute__((cfi_salt("pepper")))

    extern int foo(void) __cfi_salt;
    int f1_salt(void) __cfi_salt { /* ... */ }

Example of using this in a structure:

    #define __cfi_salt __attribute__((cfi_salt("pepper")))

    // Convenient typedefs to avoid nested declarator syntax.
    typedef int (*fptr_t)(void);
    typedef int (*fptr_salted_t)(void) __cfi_salt;

    struct widget_generator {
      fptr_t init;
      fptr_salted_t exec;
      fptr_t teardown;
    };

    // 1st .c file:
    static int internal_init(void) { /* ... */ }
    static int internal_salted_exec(void) __cfi_salt { /* ... */ }
    static int internal_teardown(void) { /* ... */ }

    static struct widget_generator _generator = {
      .init = internal_init,
      .exec = internal_salted_exec,
      .teardown = internal_teardown,
    };

    struct widget_generator *widget_gen = &_generator;

    // 2nd .c file:
    int generate_a_widget(void) {
      int ret;

      // Called with non-salted CFI.
      ret = widget_gen->init();
      if (ret)
        return ret;

      // Called with salted CFI.
      ret = widget_gen->exec();
      if (ret)
        return ret;

      // Called with non-salted CFI.
      return widget_gen->teardown();
    }

Link: https://github.com/ClangBuiltLinux/linux/issues/1736
Link: https://github.com/KSPP/linux/issues/365
Signed-off-by: Bill Wendling <mo...@google.com>
---
 clang/include/clang/Basic/Attr.td     |   7 +
 clang/include/clang/Basic/AttrDocs.td |  58 ++++++++
 clang/lib/AST/TypePrinter.cpp         |   3 +
 clang/lib/CodeGen/CGCall.h            |   7 +
 clang/lib/CodeGen/CGExpr.cpp          |   5 +-
 clang/lib/CodeGen/CGExprCXX.cpp       |  14 +-
 clang/lib/CodeGen/CodeGenFunction.cpp |  35 ++++-
 clang/lib/CodeGen/CodeGenModule.cpp   |  15 +-
 clang/lib/CodeGen/CodeGenModule.h     |   2 +-
 clang/lib/Sema/SemaDeclAttr.cpp       |  27 ++++
 clang/test/CodeGen/kcfi-salt.c        | 190 ++++++++++++++++++++++++++
 11 files changed, 345 insertions(+), 18 deletions(-)
 create mode 100644 clang/test/CodeGen/kcfi-salt.c

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index db02449a3dd12..b3a35ec3a16a4 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3830,6 +3830,13 @@ def CFICanonicalJumpTable : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def CFISalt : DeclOrTypeAttr {
+  let Spellings = [Clang<"cfi_salt">];
+  let Args = [StringArgument<"Salt">];
+  let Subjects = SubjectList<[Function, Field, Var, TypedefName], ErrorDiag>;
+  let Documentation = [CFISaltDocs];
+}
+
 // C/C++ Thread safety attributes (e.g. for deadlock, data race checking)
 // Not all of these attributes will be given a [[]] spelling. The attributes
 // which require access to function parameter names cannot use the [[]] spelling
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 65d66dd398ad1..4c3a488076c54 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -3639,6 +3639,64 @@ make the function's CFI jump table canonical. See :ref:`the CFI documentation
   }];
 }
 
+def CFISaltDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+Use ``__attribute__((cfi_salt("<salt>")))`` on a function declaration, function
+definition, or typedef to help distinguish CFI hashes between functions with
+the same type signature.
+
+Example use:
+
+.. code-block:: c
+
+  // .h file:
+  #define __cfi_salt __attribute__((cfi_salt("vogon")))
+
+  // Convenient typedefs to avoid nested declarator syntax.
+  typedef int (*fptr_t)(void); // Non-salted function call.
+  typedef int (*fptr_salted_t)(void) __cfi_salt;
+
+  struct widget_generator {
+    fptr_t init;
+    fptr_salted_t exec;
+    fptr_t teardown;
+  };
+
+  // 1st .c file:
+  static int internal_init(void) { /* ... */ }
+  static int internal_salted_exec(void) __cfi_salt { /* ... */ }
+  static int internal_teardown(void) { /* ... */ }
+
+  static struct widget_generator _generator = {
+    .init = internal_init,
+    .exec = internal_salted_exec,
+    .teardown = internal_teardown,
+  };
+
+  struct widget_generator *widget_gen = &_generator;
+
+  // 2nd .c file:
+  int generate_a_widget(void) {
+    int ret;
+
+    // Called with non-salted CFI.
+    ret = widget_gen->init();
+    if (ret)
+      return ret;
+
+    // Called with salted CFI.
+    ret = widget_gen->exec();
+    if (ret)
+      return ret;
+
+    // Called with non-salted CFI.
+    return widget_gen->teardown();
+  }
+
+  }];
+}
+
 def DocCatTypeSafety : DocumentationCategory<"Type Safety Checking"> {
   let Content = [{
 Clang supports additional attributes to enable checking type safety properties
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 4793ef38c2c46..b241892f3f836 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -2099,6 +2099,9 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::ExtVectorType:
     OS << "ext_vector_type";
     break;
+  case attr::CFISalt:
+    OS << "cfi_salt(\"" << cast<CFISaltAttr>(T->getAttr())->getSalt() << "\")";
+    break;
   }
   OS << "))";
 }
diff --git a/clang/lib/CodeGen/CGCall.h b/clang/lib/CodeGen/CGCall.h
index 0b4e3f9cb0365..f54dd64d182e5 100644
--- a/clang/lib/CodeGen/CGCall.h
+++ b/clang/lib/CodeGen/CGCall.h
@@ -96,6 +96,8 @@ class CGCallee {
     VirtualInfoStorage VirtualInfo;
   };
 
+  QualType ExprType;
+
   explicit CGCallee(SpecialKind kind) : KindOrFunctionPointer(kind) {}
 
   CGCallee(const FunctionDecl *builtinDecl, unsigned builtinID)
@@ -221,6 +223,11 @@ class CGCallee {
     return VirtualInfo.FTy;
   }
 
+  /// Retain the full type of the callee before canonicalization. This may have
+  /// attributes important for later processing (e.g. KCFI).
+  QualType getExprType() const { return ExprType; }
+  void setExprType(QualType Ty) { ExprType = Ty; }
+
   /// If this is a delayed callee computation of some sort, prepare
   /// a concrete callee.
   CGCallee prepareConcreteCallee(CodeGenFunction &CGF) const;
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index bae28b45afaa3..3811fa6a875f8 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -6260,12 +6260,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
           !cast<FunctionDecl>(TargetDecl)->isImmediateFunction()) &&
          "trying to emit a call to an immediate function");
 
+  CGCallee Callee = OrigCallee;
+  Callee.setExprType(CalleeType);
+
   CalleeType = getContext().getCanonicalType(CalleeType);
 
   auto PointeeType = cast<PointerType>(CalleeType)->getPointeeType();
 
-  CGCallee Callee = OrigCallee;
-
   if (SanOpts.has(SanitizerKind::Function) &&
       (!TargetDecl || !isa<FunctionDecl>(TargetDecl)) &&
       !isa<FunctionNoProtoType>(PointeeType)) {
diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
index 024254b0affe4..471c1b9a58cdb 100644
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -470,16 +470,14 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
 
   // Ask the ABI to load the callee.  Note that This is modified.
   llvm::Value *ThisPtrForCall = nullptr;
-  CGCallee Callee =
-    CGM.getCXXABI().EmitLoadOfMemberFunctionPointer(*this, BO, This,
-                                             ThisPtrForCall, MemFnPtr, MPT);
-
-  CallArgList Args;
-
-  QualType ThisType =
-    getContext().getPointerType(getContext().getTagDeclType(RD));
+  CGCallee Callee = CGM.getCXXABI().EmitLoadOfMemberFunctionPointer(
+      *this, BO, This, ThisPtrForCall, MemFnPtr, MPT);
+  Callee.setExprType(MemFnExpr->getType());
 
   // Push the this ptr.
+  QualType ThisType =
+      getContext().getPointerType(getContext().getTagDeclType(RD));
+  CallArgList Args;
   Args.add(RValue::get(ThisPtrForCall), ThisType);
 
   RequiredArgs required = RequiredArgs::forPrototypePlus(FPT, 1);
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index e2357563f7d56..3902b0e0b26c7 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2887,10 +2887,37 @@ void CodeGenFunction::EmitSanitizerStatReport(llvm::SanitizerStatKind SSK) {
 
 void CodeGenFunction::EmitKCFIOperandBundle(
     const CGCallee &Callee, SmallVectorImpl<llvm::OperandBundleDef> &Bundles) {
-  const FunctionProtoType *FP =
-      Callee.getAbstractInfo().getCalleeFunctionProtoType();
-  if (FP)
-    Bundles.emplace_back("kcfi", CGM.CreateKCFITypeId(FP->desugar()));
+  const CGCalleeInfo &CI = Callee.getAbstractInfo();
+  const FunctionProtoType *FP = CI.getCalleeFunctionProtoType();
+  if (!FP)
+    return;
+
+  const AttributedType *AT = nullptr;
+  GlobalDecl GD = CI.getCalleeDecl();
+  StringRef Salt;
+
+  auto GetAttributedType = [](QualType Ty) {
+    if (const auto *ET = dyn_cast<ElaboratedType>(Ty))
+      if (const auto *TT = dyn_cast<TypedefType>(ET->desugar()))
+        return dyn_cast<AttributedType>(TT->desugar());
+
+    return dyn_cast<AttributedType>(Ty);
+  };
+
+  if (GD) {
+    if (const auto *D = dyn_cast<ValueDecl>(GD.getDecl()))
+      AT = GetAttributedType(D->getType());
+  } else {
+    QualType ExprTy = Callee.getExprType();
+    if (!ExprTy.isNull())
+      AT = GetAttributedType(ExprTy);
+  }
+
+  if (AT)
+    if (const auto *CFISalt = dyn_cast<CFISaltAttr>(AT->getAttr()))
+      Salt = CFISalt->getSalt();
+
+  Bundles.emplace_back("kcfi", CGM.CreateKCFITypeId(FP->desugar(), Salt));
 }
 
 llvm::Value *
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index 6d2c705338ecf..ad2e4a714deba 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -2221,7 +2221,7 @@ llvm::ConstantInt *CodeGenModule::CreateCrossDsoCfiTypeId(llvm::Metadata *MD) {
   return llvm::ConstantInt::get(Int64Ty, llvm::MD5Hash(MDS->getString()));
 }
 
-llvm::ConstantInt *CodeGenModule::CreateKCFITypeId(QualType T) {
+llvm::ConstantInt *CodeGenModule::CreateKCFITypeId(QualType T, StringRef Salt) {
   if (auto *FnType = T->getAs<FunctionProtoType>())
     T = getContext().getFunctionType(
         FnType->getReturnType(), FnType->getParamTypes(),
@@ -2232,6 +2232,9 @@ llvm::ConstantInt *CodeGenModule::CreateKCFITypeId(QualType T) {
   getCXXABI().getMangleContext().mangleCanonicalTypeName(
       T, Out, getCodeGenOpts().SanitizeCfiICallNormalizeIntegers);
 
+  if (!Salt.empty())
+    Out << "." << Salt;
+
   if (getCodeGenOpts().SanitizeCfiICallNormalizeIntegers)
     Out << ".normalized";
 
@@ -2898,9 +2901,15 @@ void CodeGenModule::createFunctionTypeMetadataForIcall(const FunctionDecl *FD,
 void CodeGenModule::setKCFIType(const FunctionDecl *FD, llvm::Function *F) {
   llvm::LLVMContext &Ctx = F->getContext();
   llvm::MDBuilder MDB(Ctx);
+  llvm::StringRef Salt;
+
+  if (const auto *AT = dyn_cast<AttributedType>(FD->getType()))
+    if (const auto *CFISalt = dyn_cast<CFISaltAttr>(AT->getAttr()))
+      Salt = CFISalt->getSalt();
+
   F->setMetadata(llvm::LLVMContext::MD_kcfi_type,
-                 llvm::MDNode::get(
-                     Ctx, MDB.createConstant(CreateKCFITypeId(FD->getType()))));
+                 llvm::MDNode::get(Ctx, MDB.createConstant(CreateKCFITypeId(
+                                            FD->getType(), Salt))));
 }
 
 static bool allowKCFIIdentifier(StringRef Name) {
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index 1db5c3bc4e4ef..5081e1f042546 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1616,7 +1616,7 @@ class CodeGenModule : public CodeGenTypeCache {
   llvm::ConstantInt *CreateCrossDsoCfiTypeId(llvm::Metadata *MD);
 
   /// Generate a KCFI type identifier for T.
-  llvm::ConstantInt *CreateKCFITypeId(QualType T);
+  llvm::ConstantInt *CreateKCFITypeId(QualType T, StringRef Salt);
 
   /// Create a metadata identifier for the given type. This may either be an
   /// MDString (for external identifiers) or a distinct unnamed MDNode (for
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 54bac40982eda..b7bd1fd48cd0a 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6671,6 +6671,29 @@ static void handleEnforceTCBAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(AttrTy::Create(S.Context, Argument, AL));
 }
 
+static void handleCFISaltAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
+  StringRef Argument;
+  if (!S.checkStringLiteralArgumentAttr(AL, 0, Argument))
+    return;
+
+  auto *Attr = CFISaltAttr::Create(S.Context, Argument, AL);
+
+  if (auto *DD = dyn_cast<ValueDecl>(D)) {
+    QualType Ty = DD->getType();
+    if (!Ty->getAs<AttributedType>())
+      DD->setType(S.Context.getAttributedType(Attr, Ty, Ty));
+  } else if (auto *TD = dyn_cast<TypedefDecl>(D)) {
+    TypeSourceInfo *TSI = TD->getTypeSourceInfo();
+    QualType Ty = TD->getUnderlyingType();
+
+    if (!Ty->hasAttr(attr::CFISalt))
+      TD->setModedTypeSourceInfo(TSI,
+                                 S.Context.getAttributedType(Attr, Ty, Ty));
+  } else {
+    llvm_unreachable("Unexpected Decl argument type!");
+  }
+}
+
 template <typename AttrTy, typename ConflictingAttrTy>
 static AttrTy *mergeEnforceTCBAttrImpl(Sema &S, Decl *D, const AttrTy &AL) {
   // Check if the new redeclaration has different leaf-ness in the same TCB.
@@ -7473,6 +7496,10 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
     handleCountedByAttrField(S, D, AL);
     break;
 
+  case ParsedAttr::AT_CFISalt:
+    handleCFISaltAttr(S, D, AL);
+    break;
+
   // Microsoft attributes:
   case ParsedAttr::AT_LayoutVersion:
     handleLayoutVersion(S, D, AL);
diff --git a/clang/test/CodeGen/kcfi-salt.c b/clang/test/CodeGen/kcfi-salt.c
new file mode 100644
index 0000000000000..7728e0525b3ee
--- /dev/null
+++ b/clang/test/CodeGen/kcfi-salt.c
@@ -0,0 +1,190 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm -fsanitize=kcfi -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm -fsanitize=kcfi -x c++ -o - %s | FileCheck %s --check-prefixes=CHECK,MEMBER
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm -fsanitize=kcfi -fpatchable-function-entry-offset=3 -o - %s | FileCheck %s --check-prefixes=CHECK,OFFSET
+
+// Note that the interleaving of functions, which normally would be in sequence,
+// is due to the fact that Clang outputs them in a non-sequential order.
+
+#if !__has_feature(kcfi)
+#error Missing kcfi?
+#endif
+
+#define __cfi_salt __attribute__((cfi_salt("pepper")))
+
+typedef int (*fn_t)(void);
+typedef int __cfi_salt (*fn_salt_t)(void);
+
+typedef unsigned int (*ufn_t)(void);
+typedef unsigned int __cfi_salt (*ufn_salt_t)(void);
+
+/// Must emit __kcfi_typeid symbols for address-taken function declarations
+// CHECK: module asm ".weak __kcfi_typeid_[[F4:[a-zA-Z0-9_]+]]"
+// CHECK: module asm ".set __kcfi_typeid_[[F4]], [[#%d,LOW_SODIUM_HASH:]]"
+// CHECK: module asm ".weak __kcfi_typeid_[[F4_SALT:[a-zA-Z0-9_]+]]"
+// CHECK: module asm ".set __kcfi_typeid_[[F4_SALT]], [[#%d,ASM_SALTY_HASH:]]"
+
+/// Must not emit __kcfi_typeid symbols for non-address-taken declarations
+// CHECK-NOT: module asm ".weak __kcfi_typeid_{{f6|_Z2f6v}}"
+
+int f1(void);
+int f1_salt(void) __cfi_salt;
+
+unsigned int f2(void);
+unsigned int f2_salt(void) __cfi_salt;
+
+static int f3(void);
+static int f3_salt(void) __cfi_salt;
+
+extern int f4(void);
+extern int f4_salt(void) __cfi_salt;
+
+static int f5(void);
+static int f5_salt(void) __cfi_salt;
+
+extern int f6(void);
+extern int f6_salt(void) __cfi_salt;
+
+struct cfi_struct {
+  fn_t __cfi_salt fptr;
+  fn_salt_t td_fptr;
+};
+
+int f7_salt(struct cfi_struct *ptr);
+int f7_typedef_salt(struct cfi_struct *ptr);
+
+// CHECK-LABEL: @{{__call|_Z6__callPFivE}}
+// CHECK:         call{{.*}} i32
+// CHECK-NOT:     "kcfi"
+// CHECK-SAME:    ()
+int __call(fn_t f) __attribute__((__no_sanitize__("kcfi"))) {
+  return f();
+}
+
+// CHECK-LABEL: @{{call|_Z4callPFivE}}
+// CHECK:         call{{.*}} i32 %{{.}}(){{.*}} [ "kcfi"(i32 [[#LOW_SODIUM_HASH]]) ]
+// CHECK-LABEL: @{{call_salt|_Z9call_saltPFivE}}
+// CHECK:         call{{.*}} i32 %{{.}}(){{.*}} [ "kcfi"(i32 [[#%d,SALTY_HASH:]]) ]
+// CHECK-LABEL: @{{call_salt_ty|_Z12call_salt_tyPFivE}}
+// CHECK:         call{{.*}} i32 %{{.}}(){{.*}} [ "kcfi"(i32 [[#SALTY_HASH]]) ]
+int call(fn_t f) { return f(); }
+int call_salt(fn_t __cfi_salt f) { return f(); }
+int call_salt_ty(fn_salt_t f) { return f(); }
+
+// CHECK-LABEL: @{{ucall|_Z5ucallPFjvE}}
+// CHECK:         call{{.*}} i32 %{{.}}(){{.*}} [ "kcfi"(i32 [[#%d,LOW_SODIUM_UHASH:]]) ]
+// CHECK-LABEL: @{{ucall_salt|_Z10ucall_saltPFjvE}}
+// CHECK:         call{{.*}} i32 %{{.}}(){{.*}} [ "kcfi"(i32 [[#%d,SALTY_UHASH:]]) ]
+// CHECK-LABEL: @{{ucall_salt_ty|_Z13ucall_salt_tyPFjvE}}
+// CHECK:         call{{.*}} i32 %{{.}}(){{.*}} [ "kcfi"(i32 [[#SALTY_UHASH]]) ]
+unsigned int ucall(ufn_t f) { return f(); }
+unsigned int ucall_salt(ufn_t __cfi_salt f) { return f(); }
+unsigned int ucall_salt_ty(ufn_salt_t f) { return f(); }
+
+int test1(struct cfi_struct *ptr) {
+  return call(f1) +
+         call_salt(f1_salt) +
+         call_salt_ty(f1_salt) +
+
+         __call((fn_t)f2) +
+         __call((fn_t)f2_salt) +
+
+         ucall(f2) +
+         ucall_salt(f2_salt) +
+         ucall_salt_ty(f2_salt) +
+
+         call(f3) +
+         call_salt(f3_salt) +
+         call_salt_ty(f3_salt) +
+
+         call(f4) +
+         call_salt(f4_salt) +
+         call_salt_ty(f4_salt) +
+
+         f5() +
+         f5_salt() +
+
+         f6() +
+         f6_salt() +
+
+         f7_salt(ptr) +
+         f7_typedef_salt(ptr);
+}
+
+// CHECK-LABEL: define dso_local{{.*}} i32 @{{f1|_Z2f1v}}(){{.*}} !kcfi_type
+// CHECK-SAME:  ![[#LOW_SODIUM_TYPE:]]
+// CHECK-LABEL: define dso_local{{.*}} i32 @{{f1_salt|_Z7f1_saltv}}(){{.*}} !kcfi_type
+// CHECK-SAME:  ![[#SALTY_TYPE:]]
+int f1(void) { return 0; }
+int f1_salt(void) __cfi_salt { return 0; }
+
+// CHECK-LABEL: define dso_local{{.*}} i32 @{{f2|_Z2f2v}}(){{.*}} !kcfi_type
+// CHECK-SAME:  ![[#LOW_SODIUM_UTYPE:]]
+// CHECK: define dso_local{{.*}} i32 @{{f2_salt|_Z7f2_saltv}}(){{.*}} !kcfi_type
+// CHECK-SAME:  ![[#SALTY_UTYPE:]]
+unsigned int f2(void) { return 2; }
+unsigned int f2_salt(void) __cfi_salt { return 2; }
+
+// CHECK-LABEL: define internal{{.*}} i32 @{{f3|_ZL2f3v}}(){{.*}} !kcfi_type
+// CHECK-SAME:  ![[#LOW_SODIUM_TYPE]]
+// CHECK-LABEL: define internal{{.*}} i32 @{{f3_salt|_ZL7f3_saltv}}(){{.*}} !kcfi_type
+// CHECK-SAME:  ![[#SALTY_TYPE]]
+static int f3(void) { return 1; }
+static int f3_salt(void) __cfi_salt { return 1; }
+
+// CHECK: declare !kcfi_type ![[#LOW_SODIUM_TYPE]]{{.*}} i32 @[[F4]]()
+// CHECK: declare !kcfi_type ![[#SALTY_TYPE]]{{.*}} i32 @[[F4_SALT]]()
+
+/// Must not emit !kcfi_type for non-address-taken local functions
+// CHECK-LABEL: define internal{{.*}} i32 @{{f5|_ZL2f5v}}()
+// CHECK-NOT:   !kcfi_type
+// CHECK-SAME:  {
+// CHECK-LABEL: define internal{{.*}} i32 @{{f5_salt|_ZL7f5_saltv}}()
+// CHECK-NOT:   !kcfi_type
+// CHECK-SAME:  {
+static int f5(void) { return 2; }
+static int f5_salt(void) __cfi_salt { return 2; }
+
+// CHECK: declare !kcfi_type ![[#LOW_SODIUM_TYPE]]{{.*}} i32 @{{f6|_Z2f6v}}()
+// CHECK: declare !kcfi_type ![[#SALTY_TYPE]]{{.*}} i32 @{{f6_salt|_Z7f6_saltv}}()
+
+// CHECK-LABEL: @{{f7_salt|_Z7f7_saltP10cfi_struct}}
+// CHECK:         call{{.*}} i32 %{{.*}}() [ "kcfi"(i32 [[#SALTY_HASH]]) ]
+// CHECK-LABEL: @{{f7_typedef_salt|_Z15f7_typedef_saltP10cfi_struct}}
+// CHECK:         call{{.*}} i32 %{{.*}}() [ "kcfi"(i32 [[#SALTY_HASH]]) ]
+int f7_salt(struct cfi_struct *ptr) { return ptr->fptr(); }
+int f7_typedef_salt(struct cfi_struct *ptr) { return ptr->td_fptr(); }
+
+#ifdef __cplusplus
+// MEMBER-LABEL: define dso_local void @_Z16test_member_callv() #0 !kcfi_type
+// MEMBER:         call void %[[#]](ptr{{.*}} [ "kcfi"(i32 [[#%d,MEMBER_LOW_SODIUM_HASH:]]) ]
+// MEMBER:         call void %[[#]](ptr{{.*}} [ "kcfi"(i32 [[#%d,MEMBER_SALT_HASH:]]) ]
+
+// MEMBER-LABEL: define{{.*}} void @_ZN1A1fEv(ptr{{.*}} %this){{.*}} !kcfi_type
+// MEMBER-SAME:  [[#MEMBER_LOW_SODIUM_TYPE:]]
+// MEMBER-LABEL: define{{.*}} void @_ZN1A1gEv(ptr{{.*}} %this){{.*}} !kcfi_type
+// MEMBER-SAME:  [[#MEMBER_SALT_TYPE:]]
+struct A {
+  void f() {}
+  void __cfi_salt g() {}
+};
+
+void test_member_call(void) {
+  void (A::* p)() = &A::f;
+  (A().*p)();
+
+  void __cfi_salt (A::* q)() = &A::g;
+  (A().*q)();
+}
+#endif
+
+// CHECK:  ![[#]] = !{i32 4, !"kcfi", i32 1}
+// OFFSET: ![[#]] = !{i32 4, !"kcfi-offset", i32 3}
+//
+// CHECK:  ![[#LOW_SODIUM_TYPE]] = !{i32 [[#LOW_SODIUM_HASH]]}
+// CHECK:  ![[#SALTY_TYPE]] = !{i32 [[#SALTY_HASH]]}
+//
+// CHECK:  ![[#LOW_SODIUM_UTYPE]] = !{i32 [[#LOW_SODIUM_UHASH]]}
+// CHECK:  ![[#SALTY_UTYPE]] = !{i32 [[#SALTY_UHASH]]}
+//
+// MEMBER: ![[#MEMBER_LOW_SODIUM_TYPE]] = !{i32 [[#MEMBER_LOW_SODIUM_HASH]]}
+// MEMBER: ![[#MEMBER_SALT_TYPE]] = !{i32 [[#MEMBER_SALT_HASH]]}

>From 04d2c0b251308feb2a5a5518e047bd1c5fa39cd3 Mon Sep 17 00:00:00 2001
From: Bill Wendling <mo...@google.com>
Date: Mon, 2 Jun 2025 15:44:06 -0700
Subject: [PATCH 2/2] Don't use AttributedType. Instead add it as a trailing
 object to the FunctionProtoType.

---
 clang/include/clang/AST/Type.h                | 62 ++++++++++++++++---
 clang/lib/AST/ASTContext.cpp                  |  2 +
 clang/lib/AST/Type.cpp                        | 14 ++++-
 clang/lib/CodeGen/CGCall.h                    |  7 ---
 clang/lib/CodeGen/CGExpr.cpp                  |  5 +-
 clang/lib/CodeGen/CGExprCXX.cpp               | 14 +++--
 clang/lib/CodeGen/CodeGenFunction.cpp         | 27 +-------
 clang/lib/CodeGen/CodeGenModule.cpp           |  6 +-
 clang/lib/Sema/SemaDeclAttr.cpp               | 27 --------
 clang/lib/Sema/SemaType.cpp                   | 16 +++++
 ...a-attribute-supported-attributes-list.test |  1 +
 11 files changed, 102 insertions(+), 79 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 3d84f999567ca..36a42e3341e1f 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -4625,6 +4625,9 @@ class FunctionType : public Type {
     /// [implimits] 8 bits would be enough here.
     unsigned NumExceptionType : 10;
 
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned HasExtraAttributeInfo : 1;
+
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasArmTypeAttributes : 1;
 
@@ -4633,14 +4636,26 @@ class FunctionType : public Type {
     unsigned NumFunctionEffects : 4;
 
     FunctionTypeExtraBitfields()
-        : NumExceptionType(0), HasArmTypeAttributes(false),
-          EffectsHaveConditions(false), NumFunctionEffects(0) {}
+        : NumExceptionType(0), HasExtraAttributeInfo(false),
+          HasArmTypeAttributes(false), EffectsHaveConditions(false),
+          NumFunctionEffects(0) {}
+  };
+
+  /// A holder for extra information from attributes which aren't part of an
+  /// \p AttributedType.
+  struct alignas(void *) FunctionTypeExtraAttributeInfo {
+    /// A CFI "salt" that differentiates functions with the same prototype.
+    StringRef CFISalt;
+
+    operator bool() const { return !CFISalt.empty(); }
+
+    void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddString(CFISalt); }
   };
 
   /// The AArch64 SME ACLE (Arm C/C++ Language Extensions) define a number
   /// of function type attributes that can be set on function types, including
   /// function pointers.
-  enum AArch64SMETypeAttributes : unsigned {
+  enum AArch64SMETypeAttributes : uint16_t {
     SME_NormalFunction = 0,
     SME_PStateSMEnabledMask = 1 << 0,
     SME_PStateSMCompatibleMask = 1 << 1,
@@ -4670,11 +4685,11 @@ class FunctionType : public Type {
   };
 
   static ArmStateValue getArmZAState(unsigned AttrBits) {
-    return (ArmStateValue)((AttrBits & SME_ZAMask) >> SME_ZAShift);
+    return static_cast<ArmStateValue>((AttrBits & SME_ZAMask) >> SME_ZAShift);
   }
 
   static ArmStateValue getArmZT0State(unsigned AttrBits) {
-    return (ArmStateValue)((AttrBits & SME_ZT0Mask) >> SME_ZT0Shift);
+    return static_cast<ArmStateValue>((AttrBits & SME_ZT0Mask) >> SME_ZT0Shift);
   }
 
   /// A holder for Arm type attributes as described in the Arm C/C++
@@ -4683,6 +4698,7 @@ class FunctionType : public Type {
   struct alignas(void *) FunctionTypeArmAttributes {
     /// Any AArch64 SME ACLE type attributes that need to be propagated
     /// on declarations and function pointers.
+    LLVM_PREFERRED_TYPE(uint16_t)
     unsigned AArch64SMEAttributes : 9;
 
     FunctionTypeArmAttributes() : AArch64SMEAttributes(SME_NormalFunction) {}
@@ -5160,6 +5176,7 @@ class FunctionProtoType final
       private llvm::TrailingObjects<
           FunctionProtoType, QualType, SourceLocation,
           FunctionType::FunctionTypeExtraBitfields,
+          FunctionType::FunctionTypeExtraAttributeInfo,
           FunctionType::FunctionTypeArmAttributes, FunctionType::ExceptionType,
           Expr *, FunctionDecl *, FunctionType::ExtParameterInfo, Qualifiers,
           FunctionEffect, EffectConditionExpr> {
@@ -5249,15 +5266,20 @@ class FunctionProtoType final
   /// the various bits of extra information about a function prototype.
   struct ExtProtoInfo {
     FunctionType::ExtInfo ExtInfo;
-    unsigned Variadic : 1;
-    unsigned HasTrailingReturn : 1;
-    unsigned AArch64SMEAttributes : 9;
     Qualifiers TypeQuals;
     RefQualifierKind RefQualifier = RQ_None;
     ExceptionSpecInfo ExceptionSpec;
     const ExtParameterInfo *ExtParameterInfos = nullptr;
     SourceLocation EllipsisLoc;
     FunctionEffectsRef FunctionEffects;
+    FunctionTypeExtraAttributeInfo ExtraAttributeInfo;
+
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned Variadic : 1;
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned HasTrailingReturn : 1;
+    LLVM_PREFERRED_TYPE(uint16_t)
+    unsigned AArch64SMEAttributes : 9;
 
     ExtProtoInfo()
         : Variadic(false), HasTrailingReturn(false),
@@ -5276,6 +5298,7 @@ class FunctionProtoType final
     bool requiresFunctionProtoTypeExtraBitfields() const {
       return ExceptionSpec.Type == EST_Dynamic ||
              requiresFunctionProtoTypeArmAttributes() ||
+             requiresFunctionProtoTypeExtraAttributeInfo() ||
              !FunctionEffects.empty();
     }
 
@@ -5283,6 +5306,10 @@ class FunctionProtoType final
       return AArch64SMEAttributes != SME_NormalFunction;
     }
 
+    bool requiresFunctionProtoTypeExtraAttributeInfo() const {
+      return static_cast<bool>(ExtraAttributeInfo);
+    }
+
     void setArmSMEAttribute(AArch64SMETypeAttributes Kind, bool Enable = true) {
       if (Enable)
         AArch64SMEAttributes |= Kind;
@@ -5308,6 +5335,11 @@ class FunctionProtoType final
     return hasExtraBitfields();
   }
 
+  unsigned
+  numTrailingObjects(OverloadToken<FunctionTypeExtraAttributeInfo>) const {
+    return hasExtraAttributeInfo();
+  }
+
   unsigned numTrailingObjects(OverloadToken<ExceptionType>) const {
     return getExceptionSpecSize().NumExceptionType;
   }
@@ -5404,6 +5436,12 @@ class FunctionProtoType final
 
   }
 
+  bool hasExtraAttributeInfo() const {
+    return FunctionTypeBits.HasExtraBitfields &&
+           getTrailingObjects<FunctionTypeExtraBitfields>()
+               ->HasExtraAttributeInfo;
+  }
+
   bool hasArmTypeAttributes() const {
     return FunctionTypeBits.HasExtraBitfields &&
            getTrailingObjects<FunctionTypeExtraBitfields>()
@@ -5436,6 +5474,7 @@ class FunctionProtoType final
     EPI.TypeQuals = getMethodQuals();
     EPI.RefQualifier = getRefQualifier();
     EPI.ExtParameterInfos = getExtParameterInfosOrNull();
+    EPI.ExtraAttributeInfo = getExtraAttributeInfo();
     EPI.AArch64SMEAttributes = getAArch64SMEAttributes();
     EPI.FunctionEffects = getFunctionEffects();
     return EPI;
@@ -5619,6 +5658,13 @@ class FunctionProtoType final
     return getTrailingObjects<ExtParameterInfo>();
   }
 
+  /// Return the extra attribute information.
+  FunctionTypeExtraAttributeInfo getExtraAttributeInfo() const {
+    if (hasExtraAttributeInfo())
+      return *getTrailingObjects<FunctionTypeExtraAttributeInfo>();
+    return FunctionTypeExtraAttributeInfo();
+  }
+
   /// Return a bitmask describing the SME attributes on the function type, see
   /// AArch64SMETypeAttributes for their values.
   unsigned getAArch64SMEAttributes() const {
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index f4033c02f0d2c..ce0b4d4a863c6 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -5058,10 +5058,12 @@ QualType ASTContext::getFunctionTypeInternal(
       EPI.ExceptionSpec.Type, EPI.ExceptionSpec.Exceptions.size());
   size_t Size = FunctionProtoType::totalSizeToAlloc<
       QualType, SourceLocation, FunctionType::FunctionTypeExtraBitfields,
+      FunctionType::FunctionTypeExtraAttributeInfo,
       FunctionType::FunctionTypeArmAttributes, FunctionType::ExceptionType,
       Expr *, FunctionDecl *, FunctionProtoType::ExtParameterInfo, Qualifiers,
       FunctionEffect, EffectConditionExpr>(
       NumArgs, EPI.Variadic, EPI.requiresFunctionProtoTypeExtraBitfields(),
+      EPI.requiresFunctionProtoTypeExtraAttributeInfo(),
       EPI.requiresFunctionProtoTypeArmAttributes(), ESH.NumExceptionType,
       ESH.NumExprPtr, ESH.NumFunctionDeclPtr,
       EPI.ExtParameterInfos ? NumArgs : 0,
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 35a5f8ec32ab0..a6c334e21e2c9 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -3672,6 +3672,16 @@ FunctionProtoType::FunctionProtoType(QualType result, ArrayRef<QualType> params,
     FunctionTypeBits.HasExtraBitfields = false;
   }
 
+  // Propagate any extra attribute information.
+  if (epi.requiresFunctionProtoTypeExtraAttributeInfo()) {
+    auto &ExtraAttrInfo = *getTrailingObjects<FunctionTypeExtraAttributeInfo>();
+    ExtraAttrInfo.CFISalt = epi.ExtraAttributeInfo.CFISalt;
+
+    // Also set the bit in FunctionTypeExtraBitfields.
+    auto &ExtraBits = *getTrailingObjects<FunctionTypeExtraBitfields>();
+    ExtraBits.HasExtraAttributeInfo = true;
+  }
+
   if (epi.requiresFunctionProtoTypeArmAttributes()) {
     auto &ArmTypeAttrs = *getTrailingObjects<FunctionTypeArmAttributes>();
     ArmTypeAttrs = FunctionTypeArmAttributes();
@@ -3889,7 +3899,8 @@ void FunctionProtoType::Profile(llvm::FoldingSetNodeID &ID, QualType Result,
   // This is followed by the ext info:
   //      int
   // Finally we have a trailing return type flag (bool)
-  // combined with AArch64 SME Attributes, to save space:
+  // combined with AArch64 SME Attributes and extra attribute info, to save
+  // space:
   //      int
   // combined with any FunctionEffects
   //
@@ -3924,6 +3935,7 @@ void FunctionProtoType::Profile(llvm::FoldingSetNodeID &ID, QualType Result,
   }
 
   epi.ExtInfo.Profile(ID);
+  epi.ExtraAttributeInfo.Profile(ID);
 
   unsigned EffectCount = epi.FunctionEffects.size();
   bool HasConds = !epi.FunctionEffects.Conditions.empty();
diff --git a/clang/lib/CodeGen/CGCall.h b/clang/lib/CodeGen/CGCall.h
index f54dd64d182e5..0b4e3f9cb0365 100644
--- a/clang/lib/CodeGen/CGCall.h
+++ b/clang/lib/CodeGen/CGCall.h
@@ -96,8 +96,6 @@ class CGCallee {
     VirtualInfoStorage VirtualInfo;
   };
 
-  QualType ExprType;
-
   explicit CGCallee(SpecialKind kind) : KindOrFunctionPointer(kind) {}
 
   CGCallee(const FunctionDecl *builtinDecl, unsigned builtinID)
@@ -223,11 +221,6 @@ class CGCallee {
     return VirtualInfo.FTy;
   }
 
-  /// Retain the full type of the callee before canonicalization. This may have
-  /// attributes important for later processing (e.g. KCFI).
-  QualType getExprType() const { return ExprType; }
-  void setExprType(QualType Ty) { ExprType = Ty; }
-
   /// If this is a delayed callee computation of some sort, prepare
   /// a concrete callee.
   CGCallee prepareConcreteCallee(CodeGenFunction &CGF) const;
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 3811fa6a875f8..bae28b45afaa3 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -6260,13 +6260,12 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
           !cast<FunctionDecl>(TargetDecl)->isImmediateFunction()) &&
          "trying to emit a call to an immediate function");
 
-  CGCallee Callee = OrigCallee;
-  Callee.setExprType(CalleeType);
-
   CalleeType = getContext().getCanonicalType(CalleeType);
 
   auto PointeeType = cast<PointerType>(CalleeType)->getPointeeType();
 
+  CGCallee Callee = OrigCallee;
+
   if (SanOpts.has(SanitizerKind::Function) &&
       (!TargetDecl || !isa<FunctionDecl>(TargetDecl)) &&
       !isa<FunctionNoProtoType>(PointeeType)) {
diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
index 471c1b9a58cdb..024254b0affe4 100644
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -470,14 +470,16 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
 
   // Ask the ABI to load the callee.  Note that This is modified.
   llvm::Value *ThisPtrForCall = nullptr;
-  CGCallee Callee = CGM.getCXXABI().EmitLoadOfMemberFunctionPointer(
-      *this, BO, This, ThisPtrForCall, MemFnPtr, MPT);
-  Callee.setExprType(MemFnExpr->getType());
+  CGCallee Callee =
+    CGM.getCXXABI().EmitLoadOfMemberFunctionPointer(*this, BO, This,
+                                             ThisPtrForCall, MemFnPtr, MPT);
 
-  // Push the this ptr.
-  QualType ThisType =
-      getContext().getPointerType(getContext().getTagDeclType(RD));
   CallArgList Args;
+
+  QualType ThisType =
+    getContext().getPointerType(getContext().getTagDeclType(RD));
+
+  // Push the this ptr.
   Args.add(RValue::get(ThisPtrForCall), ThisType);
 
   RequiredArgs required = RequiredArgs::forPrototypePlus(FPT, 1);
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 3902b0e0b26c7..9be41d7079958 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2892,30 +2892,9 @@ void CodeGenFunction::EmitKCFIOperandBundle(
   if (!FP)
     return;
 
-  const AttributedType *AT = nullptr;
-  GlobalDecl GD = CI.getCalleeDecl();
   StringRef Salt;
-
-  auto GetAttributedType = [](QualType Ty) {
-    if (const auto *ET = dyn_cast<ElaboratedType>(Ty))
-      if (const auto *TT = dyn_cast<TypedefType>(ET->desugar()))
-        return dyn_cast<AttributedType>(TT->desugar());
-
-    return dyn_cast<AttributedType>(Ty);
-  };
-
-  if (GD) {
-    if (const auto *D = dyn_cast<ValueDecl>(GD.getDecl()))
-      AT = GetAttributedType(D->getType());
-  } else {
-    QualType ExprTy = Callee.getExprType();
-    if (!ExprTy.isNull())
-      AT = GetAttributedType(ExprTy);
-  }
-
-  if (AT)
-    if (const auto *CFISalt = dyn_cast<CFISaltAttr>(AT->getAttr()))
-      Salt = CFISalt->getSalt();
+  if (const auto &Info = FP->getExtraAttributeInfo())
+    Salt = Info.CFISalt;
 
   Bundles.emplace_back("kcfi", CGM.CreateKCFITypeId(FP->desugar(), Salt));
 }
@@ -3376,4 +3355,4 @@ void CodeGenFunction::addInstToNewSourceAtom(llvm::Instruction *KeyInstruction,
     ApplyAtomGroup Grp(getDebugInfo());
     DI->addInstToCurrentSourceAtom(KeyInstruction, Backup);
   }
-}
\ No newline at end of file
+}
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index ad2e4a714deba..443a21685c14b 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -2903,9 +2903,9 @@ void CodeGenModule::setKCFIType(const FunctionDecl *FD, llvm::Function *F) {
   llvm::MDBuilder MDB(Ctx);
   llvm::StringRef Salt;
 
-  if (const auto *AT = dyn_cast<AttributedType>(FD->getType()))
-    if (const auto *CFISalt = dyn_cast<CFISaltAttr>(AT->getAttr()))
-      Salt = CFISalt->getSalt();
+  if (const auto *FP = FD->getType()->getAs<FunctionProtoType>())
+    if (const auto &Info = FP->getExtraAttributeInfo())
+      Salt = Info.CFISalt;
 
   F->setMetadata(llvm::LLVMContext::MD_kcfi_type,
                  llvm::MDNode::get(Ctx, MDB.createConstant(CreateKCFITypeId(
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index b7bd1fd48cd0a..54bac40982eda 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6671,29 +6671,6 @@ static void handleEnforceTCBAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(AttrTy::Create(S.Context, Argument, AL));
 }
 
-static void handleCFISaltAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  StringRef Argument;
-  if (!S.checkStringLiteralArgumentAttr(AL, 0, Argument))
-    return;
-
-  auto *Attr = CFISaltAttr::Create(S.Context, Argument, AL);
-
-  if (auto *DD = dyn_cast<ValueDecl>(D)) {
-    QualType Ty = DD->getType();
-    if (!Ty->getAs<AttributedType>())
-      DD->setType(S.Context.getAttributedType(Attr, Ty, Ty));
-  } else if (auto *TD = dyn_cast<TypedefDecl>(D)) {
-    TypeSourceInfo *TSI = TD->getTypeSourceInfo();
-    QualType Ty = TD->getUnderlyingType();
-
-    if (!Ty->hasAttr(attr::CFISalt))
-      TD->setModedTypeSourceInfo(TSI,
-                                 S.Context.getAttributedType(Attr, Ty, Ty));
-  } else {
-    llvm_unreachable("Unexpected Decl argument type!");
-  }
-}
-
 template <typename AttrTy, typename ConflictingAttrTy>
 static AttrTy *mergeEnforceTCBAttrImpl(Sema &S, Decl *D, const AttrTy &AL) {
   // Check if the new redeclaration has different leaf-ness in the same TCB.
@@ -7496,10 +7473,6 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
     handleCountedByAttrField(S, D, AL);
     break;
 
-  case ParsedAttr::AT_CFISalt:
-    handleCFISaltAttr(S, D, AL);
-    break;
-
   // Microsoft attributes:
   case ParsedAttr::AT_LayoutVersion:
     handleLayoutVersion(S, D, AL);
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 338b81fe89748..5f05ccdc308d8 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -163,6 +163,7 @@ static void diagnoseBadTypeAttribute(Sema &S, const ParsedAttr &attr,
   case ParsedAttr::AT_ArmOut:                                                  \
   case ParsedAttr::AT_ArmInOut:                                                \
   case ParsedAttr::AT_ArmAgnostic:                                             \
+  case ParsedAttr::AT_CFISalt:                                                 \
   case ParsedAttr::AT_AnyX86NoCallerSavedRegisters:                            \
   case ParsedAttr::AT_AnyX86NoCfCheck:                                         \
     CALLING_CONV_ATTRS_CASELIST
@@ -7921,6 +7922,21 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
     return true;
   }
 
+  if (attr.getKind() == ParsedAttr::AT_CFISalt) {
+    StringRef Argument;
+    if (!S.checkStringLiteralArgumentAttr(attr, 0, Argument))
+      return false;
+
+    const auto *FnTy = unwrapped.get()->getAs<FunctionProtoType>();
+    FunctionProtoType::ExtProtoInfo EPI = FnTy->getExtProtoInfo();
+    EPI.ExtraAttributeInfo.CFISalt = Argument;
+
+    QualType newtype = S.Context.getFunctionType(FnTy->getReturnType(),
+                                                 FnTy->getParamTypes(), EPI);
+    type = unwrapped.wrap(S, newtype->getAs<FunctionType>());
+    return true;
+  }
+
   if (attr.getKind() == ParsedAttr::AT_ArmStreaming ||
       attr.getKind() == ParsedAttr::AT_ArmStreamingCompatible ||
       attr.getKind() == ParsedAttr::AT_ArmPreserves ||
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index bf64c388b0436..6ffb7cf1d4140 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -31,6 +31,7 @@
 // CHECK-NEXT: CFConsumed (SubjectMatchRule_variable_is_parameter)
 // CHECK-NEXT: CFGuard (SubjectMatchRule_function)
 // CHECK-NEXT: CFICanonicalJumpTable (SubjectMatchRule_function)
+// CHECK-NEXT: CFISalt (SubjectMatchRule_function, SubjectMatchRule_field, SubjectMatchRule_variable, SubjectMatchRule_type_alias)
 // CHECK-NEXT: CFUnknownTransfer (SubjectMatchRule_function)
 // CHECK-NEXT: CPUDispatch (SubjectMatchRule_function)
 // CHECK-NEXT: CPUSpecific (SubjectMatchRule_function)

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to