zoecarver created this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.
zoecarver requested review of this revision.

Adds `__builtin_zero_non_value_bits`, which zeroes all padding bits of the
struct that its argument points to.

Unions and bit-fields are not supported yet.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D87974

Files:
  clang/include/clang/Basic/Builtins.def
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGenCXX/builtin-zero-non-value-bits.cpp
  clang/test/SemaCXX/builtin-zero-non-value-bits.cpp

Index: clang/test/SemaCXX/builtin-zero-non-value-bits.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/builtin-zero-non-value-bits.cpp
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -fsyntax-only -verify %s
+
+struct Foo { }; // Record type; even an empty one is a valid pointee.
+
+void test(int a, Foo b, void* c, int *d, Foo *e) {
+  __builtin_zero_non_value_bits(a); // expected-error {{passing 'int' to parameter of incompatible type structure pointer: type mismatch at 1st parameter ('int' vs structure pointer)}}
+  __builtin_zero_non_value_bits(b); // expected-error {{passing 'Foo' to parameter of incompatible type structure pointer: type mismatch at 1st parameter ('Foo' vs structure pointer)}}
+  __builtin_zero_non_value_bits(c); // expected-error {{passing 'void *' to parameter of incompatible type structure pointer: type mismatch at 1st parameter ('void *' vs structure pointer)}}
+  __builtin_zero_non_value_bits(d); // expected-error {{passing 'int *' to parameter of incompatible type structure pointer: type mismatch at 1st parameter ('int *' vs structure pointer)}}
+  __builtin_zero_non_value_bits(e); // OK: a pointer to a record type is accepted.
+}
Index: clang/test/CodeGenCXX/builtin-zero-non-value-bits.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCXX/builtin-zero-non-value-bits.cpp
@@ -0,0 +1,180 @@
+// RUN: mkdir -p %t
+// RUN: %clang++ %s -o %t/run
+// RUN: %t/run
+
+// Runtime test for __builtin_zero_non_value_bits. For each layout below, one
+// object is memset to 0 and another to 42 before the same member values are
+// stored into both; once the builtin has cleared the second object's padding,
+// the two must compare equal byte for byte.
+
+#include <cassert>
+#include <cstddef>
+#include <cstring>
+
+template<std::size_t A1, std::size_t A2, class T>
+struct alignas(A1) BasicWithPadding {
+  T x;
+  alignas(A2) T y;
+};
+
+template<std::size_t A1, std::size_t A2, std::size_t N, class T>
+struct alignas(A1) ArrWithPadding {
+  T x[N];
+  alignas(A2) char c;
+  T y[N];
+};
+
+template<std::size_t A1, std::size_t A2, class T>
+struct alignas(A1) PtrWithPadding {
+  T *x;
+  alignas(A2) T *y;
+};
+
+template<std::size_t A1, std::size_t A2, std::size_t A3, class T>
+struct alignas(A1) ThreeWithPadding {
+  T x;
+  alignas(A2) T y;
+  alignas(A3) T z;
+};
+
+// No explicit alignment: whether this has any padding depends on T.
+template<class T>
+struct Normal {
+  T a;
+  T b;
+};
+
+template<class T>
+struct X {
+  T x;
+};
+
+template<class T>
+struct Z {
+  T z;
+};
+
+template<std::size_t A, class T>
+struct YZ : public Z<T> {
+  alignas(A) T y;
+};
+
+// Padding both between base class subobjects and inside them.
+template<std::size_t A1, std::size_t A2, class T>
+struct alignas(A1) HasBase : public X<T>, public YZ<A2, T> {
+  T a;
+  alignas(A2) T b;
+};
+
+template<std::size_t A1, std::size_t A2, class T>
+void testAllForType(T a, T b, T c, T d) {
+  using B = BasicWithPadding<A1, A2, T>;
+  B basic1;
+  memset(&basic1, 0, sizeof(B));
+  basic1.x = a;
+  basic1.y = b;
+  B basic2;
+  memset(&basic2, 42, sizeof(B));
+  basic2.x = a;
+  basic2.y = b;
+  // The members are equal, so any difference comes from the padding bytes.
+  assert(memcmp(&basic1, &basic2, sizeof(B)) != 0);
+  __builtin_zero_non_value_bits(&basic2);
+  assert(memcmp(&basic1, &basic2, sizeof(B)) == 0);
+
+  using A = ArrWithPadding<A1, A2, 2, T>;
+  A arr1;
+  memset(&arr1, 0, sizeof(A));
+  arr1.x[0] = a;
+  arr1.x[1] = b;
+  arr1.y[0] = c;
+  arr1.y[1] = d;
+  A arr2;
+  memset(&arr2, 42, sizeof(A));
+  arr2.x[0] = a;
+  arr2.x[1] = b;
+  arr2.y[0] = c;
+  arr2.y[1] = d;
+  arr2.c = 0;
+  assert(memcmp(&arr1, &arr2, sizeof(A)) != 0);
+  __builtin_zero_non_value_bits(&arr2);
+  assert(memcmp(&arr1, &arr2, sizeof(A)) == 0);
+
+  using P = PtrWithPadding<A1, A2, T>;
+  P ptr1;
+  memset(&ptr1, 0, sizeof(P));
+  ptr1.x = &a;
+  ptr1.y = &b;
+  P ptr2;
+  memset(&ptr2, 42, sizeof(P));
+  ptr2.x = &a;
+  ptr2.y = &b;
+  assert(memcmp(&ptr1, &ptr2, sizeof(P)) != 0);
+  __builtin_zero_non_value_bits(&ptr2);
+  assert(memcmp(&ptr1, &ptr2, sizeof(P)) == 0);
+
+  using Three = ThreeWithPadding<A1, A2, A2, T>;
+  Three three1;
+  memset(&three1, 0, sizeof(Three));
+  three1.x = a;
+  three1.y = b;
+  three1.z = c;
+  Three three2;
+  memset(&three2, 42, sizeof(Three));
+  three2.x = a;
+  three2.y = b;
+  three2.z = c;
+  assert(memcmp(&three1, &three2, sizeof(Three)) != 0);
+  __builtin_zero_non_value_bits(&three2);
+  assert(memcmp(&three1, &three2, sizeof(Three)) == 0);
+
+  // Normal<T> may have no padding at all, so there is no "!= 0" check here.
+  using N = Normal<T>;
+  N normal1;
+  memset(&normal1, 0, sizeof(N));
+  normal1.a = a;
+  normal1.b = b;
+  N normal2;
+  memset(&normal2, 42, sizeof(N));
+  normal2.a = a;
+  normal2.b = b;
+  __builtin_zero_non_value_bits(&normal2);
+  assert(memcmp(&normal1, &normal2, sizeof(N)) == 0);
+
+  using H = HasBase<A1, A2, T>;
+  H base1;
+  memset(&base1, 0, sizeof(H));
+  base1.a = a;
+  base1.b = b;
+  base1.x = c;
+  base1.y = d;
+  base1.z = a;
+  H base2;
+  memset(&base2, 42, sizeof(H));
+  base2.a = a;
+  base2.b = b;
+  base2.x = c;
+  base2.y = d;
+  base2.z = a;
+  assert(memcmp(&base1, &base2, sizeof(H)) != 0);
+  __builtin_zero_non_value_bits(&base2);
+  assert(memcmp(&base1, &base2, sizeof(H)) == 0);
+}
+
+struct Foo {
+  int x;
+  int y;
+};
+
+int main() {
+  testAllForType<32, 16, char>(11, 22, 33, 44);
+  testAllForType<64, 32, char>(4, 5, 6, 7);
+  testAllForType<32, 16, int>(0, 1, 2, 3);
+  testAllForType<64, 32, int>(4, 5, 6, 7);
+  testAllForType<32, 16, double>(0, 1, 2, 3);
+  testAllForType<64, 32, double>(4, 5, 6, 7);
+  testAllForType<32, 16, Foo>(Foo{1, 2}, Foo{3, 4}, Foo{1, 2}, Foo{3, 4});
+  testAllForType<64, 32, Foo>(Foo{1, 2}, Foo{3, 4}, Foo{1, 2}, Foo{3, 4});
+
+  return 0;
+}
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1593,6 +1593,27 @@
   }
   case Builtin::BI__builtin_launder:
     return SemaBuiltinLaunder(*this, TheCall);
+  case Builtin::BI__builtin_zero_non_value_bits: {
+    // Declared variadic ("v.") in Builtins.def, so nothing earlier limits the
+    // argument count; require exactly one before calling getArg(0).
+    if (checkArgCount(*this, TheCall, 1))
+      return ExprError();
+    const Expr *PtrArg = TheCall->getArg(0)->IgnoreParenImpCasts();
+    const QualType PtrArgType = PtrArg->getType();
+    // Only a pointer to a record type is accepted.
+    if (!PtrArgType->isPointerType() ||
+        !PtrArgType->getPointeeType()->isRecordType()) {
+      Diag(PtrArg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+          << PtrArgType << "structure pointer" << 1 << 0 << 3 << 1 << PtrArgType
+          << "structure pointer";
+      return ExprError();
+    }
+    // CodeGen walks the record's layout, so the pointee must be complete.
+    if (RequireCompleteType(PtrArg->getBeginLoc(), PtrArgType->getPointeeType(),
+                            diag::err_incomplete_type))
+      return ExprError();
+    break;
+  }
   case Builtin::BI__sync_fetch_and_add:
   case Builtin::BI__sync_fetch_and_add_1:
   case Builtin::BI__sync_fetch_and_add_2:
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -1642,6 +1642,71 @@
   return RValue::get(Builder.CreateCall(F, { Src, Src, ShiftAmt }));
 }
 
+/// Zero every byte of the record pointed to by \p Ptr (of clang type \p Ty)
+/// that is not covered by a non-empty base class subobject or by a field,
+/// recursing into those bases and into record-typed fields.
+/// NOTE(review): unions, bit-fields, virtual bases (getBaseClassOffset only
+/// covers non-virtual bases) and arrays of records are not handled yet.
+static void RecursivelyZeroNonValueBits(CodeGenFunction &CGF, Value *Ptr,
+                                        QualType Ty) {
+  ASTContext &Ctx = CGF.getContext();
+  const llvm::DataLayout &DL = CGF.CGM.getModule().getDataLayout();
+  auto *I8Ptr = CGF.Builder.CreateBitCast(Ptr, CGF.Int8PtrTy);
+  auto *Zero = ConstantInt::get(CGF.Int8Ty, 0);
+  // Store a zero byte at every offset in [From, To).
+  auto ZeroRange = [&](uint64_t From, uint64_t To) {
+    for (; From < To; ++From) {
+      auto *Element =
+          CGF.Builder.CreateGEP(I8Ptr, ConstantInt::get(CGF.IntTy, From));
+      CGF.Builder.CreateAlignedStore(Zero, Element, MaybeAlign());
+    }
+  };
+
+  auto *ST = cast<StructType>(Ptr->getType()->getPointerElementType());
+  const StructLayout *SL = DL.getStructLayout(ST);
+  const auto *R = cast<CXXRecordDecl>(Ty->getAsRecordDecl());
+  const ASTRecordLayout &ASTLayout = Ctx.getASTRecordLayout(R);
+  // A vtable pointer introduced by this class holds a value; don't clear it.
+  // NOTE(review): assumes that vptr is placed at offset 0 - confirm per ABI.
+  uint64_t RunningOffset =
+      ASTLayout.hasOwnVFPtr() ? CGF.getPointerSize().getQuantity() : 0;
+
+  for (const CXXBaseSpecifier &Base : R->bases()) {
+    const auto *BaseRecord = Base.getType()->getAsCXXRecordDecl();
+    // Empty bases have no LLVM struct element of their own, so looking one
+    // up by offset would find an unrelated element; they hold no value bits.
+    if (BaseRecord->isEmpty())
+      continue;
+    uint64_t Offset = ASTLayout.getBaseClassOffset(BaseRecord).getQuantity();
+    ZeroRange(RunningOffset, Offset);
+    unsigned Index = SL->getElementContainingOffset(Offset);
+    RecursivelyZeroNonValueBits(CGF, CGF.Builder.CreateStructGEP(Ptr, Index),
+                                Base.getType());
+    // Use the LLVM StructType layout so we pick up on packed base types.
+    auto *BaseST = cast<StructType>(ST->getElementType(Index));
+    RunningOffset = Offset + DL.getStructLayout(BaseST)->getSizeInBytes();
+  }
+
+  for (const FieldDecl *Field : R->fields()) {
+    // Zero-sized fields (e.g. [[no_unique_address]]) have no storage.
+    if (Field->isZeroSize(Ctx))
+      continue;
+    uint64_t OffsetInBits = ASTLayout.getFieldOffset(Field->getFieldIndex());
+    uint64_t Offset = Ctx.toCharUnitsFromBits(OffsetInBits).getQuantity();
+    ZeroRange(RunningOffset, Offset);
+    unsigned Index = SL->getElementContainingOffset(Offset);
+    // A record-typed field may contain padding of its own.
+    if (Field->getType()->isRecordType())
+      RecursivelyZeroNonValueBits(CGF, CGF.Builder.CreateStructGEP(Ptr, Index),
+                                  Field->getType());
+    uint64_t SizeInBits =
+        DL.getTypeSizeInBits(ST->getElementType(Index)).getKnownMinSize();
+    RunningOffset = Offset + SizeInBits / 8;
+  }
+  // Clear everything after the last subobject, including tail padding.
+  ZeroRange(RunningOffset, SL->getSizeInBytes());
+}
+
 RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
                                         const CallExpr *E,
                                         ReturnValueSlot ReturnValue) {
@@ -2946,6 +3011,13 @@
 
     return RValue::get(Ptr);
   }
+  case Builtin::BI__builtin_zero_non_value_bits: {
+    // Sema rejects anything but a pointer to a record type for this builtin.
+    const Expr *PtrArg = E->getArg(0);
+    QualType RecordTy = PtrArg->getType()->getPointeeType();
+    RecursivelyZeroNonValueBits(*this, EmitScalarExpr(PtrArg), RecordTy);
+    return RValue::get(nullptr);
+  }
   case Builtin::BI__sync_fetch_and_add:
   case Builtin::BI__sync_fetch_and_sub:
   case Builtin::BI__sync_fetch_and_or:
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -534,6 +534,7 @@
 BUILTIN(__builtin_thread_pointer, "v*", "nc")
 BUILTIN(__builtin_launder, "v*v*", "nt")
 LANGBUILTIN(__builtin_is_constant_evaluated, "b", "n", CXX_LANG)
+LANGBUILTIN(__builtin_zero_non_value_bits, "v.", "n", CXX_LANG) // Arg: pointer to a record; checked in Sema.
 
 // GCC exception builtins
 BUILTIN(__builtin_eh_return, "vzv*", "r") // FIXME: Takes intptr_t, not size_t!
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to