https://github.com/spall created 
https://github.com/llvm/llvm-project/pull/140627

In 'EmitStoreThroughExtVectorComponentLValue', move the code that zero-extends
(ZExt) the source value when the destination scalar type is larger than the
source scalar type to the top of the function, so that every case handles it.

The previous code missed this case:
```
bool4 b = true.xxxx;
b.xyz = false.xxx;
```
This led to a bad shuffle vector being generated.

Closes #140564

>From 34ae69db10fec2ccdd26790f71fcdb3fc612e640 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahsp...@microsoft.com>
Date: Sun, 18 May 2025 19:16:49 -0700
Subject: [PATCH] zext at top of function instead of in each case

update tests + new test
---
 clang/lib/CodeGen/CGExpr.cpp                  | 22 +++++++---------
 .../CodeGenHLSL/builtins/ScalarSwizzles.hlsl  | 25 ++++++++++++++++---
 2 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 37a5678aa61d5..7580a490a66c1 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2694,14 +2694,20 @@ void 
CodeGenFunction::EmitStoreThroughBitfieldLValue(RValue Src, LValue Dst,
 
 void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
                                                                LValue Dst) {
+  llvm::Value *SrcVal = Src.getScalarVal();
+  Address DstAddr = Dst.getExtVectorAddress();
+  if (DstAddr.getElementType()->getScalarSizeInBits() >
+      SrcVal->getType()->getScalarSizeInBits())
+    SrcVal = Builder.CreateZExt(
+        SrcVal, convertTypeForLoadStore(Dst.getType(), SrcVal->getType()));
+
   // HLSL allows storing to scalar values through ExtVector component LValues.
   // To support this we need to handle the case where the destination address 
is
   // a scalar.
-  Address DstAddr = Dst.getExtVectorAddress();
   if (!DstAddr.getElementType()->isVectorTy()) {
     assert(!Dst.getType()->isVectorType() &&
            "this should only occur for non-vector l-values");
-    Builder.CreateStore(Src.getScalarVal(), DstAddr, 
Dst.isVolatileQualified());
+    Builder.CreateStore(SrcVal, DstAddr, Dst.isVolatileQualified());
     return;
   }
 
@@ -2722,11 +2728,6 @@ void 
CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
       for (unsigned i = 0; i != NumSrcElts; ++i)
         Mask[getAccessedFieldNo(i, Elts)] = i;
 
-      llvm::Value *SrcVal = Src.getScalarVal();
-      if (VecTy->getScalarSizeInBits() >
-          SrcVal->getType()->getScalarSizeInBits())
-        SrcVal = Builder.CreateZExt(SrcVal, VecTy);
-
       Vec = Builder.CreateShuffleVector(SrcVal, Mask);
     } else if (NumDstElts > NumSrcElts) {
       // Extended the source vector to the same length and then shuffle it
@@ -2737,8 +2738,7 @@ void 
CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
       for (unsigned i = 0; i != NumSrcElts; ++i)
         ExtMask.push_back(i);
       ExtMask.resize(NumDstElts, -1);
-      llvm::Value *ExtSrcVal =
-          Builder.CreateShuffleVector(Src.getScalarVal(), ExtMask);
+      llvm::Value *ExtSrcVal = Builder.CreateShuffleVector(SrcVal, ExtMask);
       // build identity
       SmallVector<int, 4> Mask;
       for (unsigned i = 0; i != NumDstElts; ++i)
@@ -2764,10 +2764,6 @@ void 
CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
     unsigned InIdx = getAccessedFieldNo(0, Elts);
     llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
 
-    llvm::Value *SrcVal = Src.getScalarVal();
-    if (VecTy->getScalarSizeInBits() > 
SrcVal->getType()->getScalarSizeInBits())
-      SrcVal = Builder.CreateZExt(SrcVal, VecTy->getScalarType());
-
     Vec = Builder.CreateInsertElement(Vec, SrcVal, Elt);
   }
 
diff --git a/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl 
b/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
index 96e17046ee934..8a3958ad8fd04 100644
--- a/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl
@@ -233,7 +233,8 @@ int AssignInt(int V){
 
 // CHECK: lor.end:
 // CHECK-NEXT: [[H:%.*]] = phi i1 [ true, %entry ], [ [[G]], %lor.rhs ]
-// CHECK-NEXT: store i1 [[H]], ptr [[XAddr]], align 4
+// CHECK-NEXT: [[J:%.*]] = zext i1 %9 to i32
+// CHECK-NEXT: store i32 [[J]], ptr [[XAddr]], align 4
 // CHECK-NEXT: [[I:%.*]] = load i32, ptr [[XAddr]], align 4
 // CHECK-NEXT: [[LoadV:%.*]] = trunc i32 [[I]] to i1
 // CHECK-NEXT: ret i1 [[LoadV]]
@@ -257,8 +258,8 @@ bool AssignBool(bool V) {
 // CHECK-NEXT: store <2 x i32> [[A]], ptr [[X]], align 8
 // CHECK-NEXT: [[B:%.*]] = load i32, ptr [[VAddr]], align 4
 // CHECK-NEXT: [[LV1:%.*]] = trunc i32 [[B]] to i1
-// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[D:%.*]] = zext i1 [[LV1]] to i32
+// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[E:%.*]] = insertelement <2 x i32> [[C]], i32 [[D]], i32 1
 // CHECK-NEXT: store <2 x i32> [[E]], ptr [[X]], align 8
 // CHECK-NEXT: ret void
@@ -275,8 +276,8 @@ void AssignBool2(bool V) {
 // CHECK-NEXT: store <2 x i32> splat (i32 1), ptr [[X]], align 8
 // CHECK-NEXT: [[Z:%.*]] = load <2 x i32>, ptr [[VAddr]], align 8
 // CHECK-NEXT: [[LV:%.*]] = trunc <2 x i32> [[Z]] to <2 x i1>
-// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[B:%.*]] = zext <2 x i1> [[LV]] to <2 x i32>
+// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
 // CHECK-NEXT: [[C:%.*]] = shufflevector <2 x i32> [[B]], <2 x i32> poison, <2 
x i32> <i32 0, i32 1>
 // CHECK-NEXT: store <2 x i32> [[C]], ptr [[X]], align 8
 // CHECK-NEXT: ret void
@@ -302,3 +303,21 @@ bool2 AccessBools() {
   bool4 X = true.xxxx;
   return X.zw;
 }
+
+// CHECK-LABEL: define void {{.*}}BoolSizeMismatch{{.*}}
+// CHECK: [[B:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: [[Tmp:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[B]], align 16
+// CHECK-NEXT: store <1 x i32> zeroinitializer, ptr [[Tmp]], align 4
+// CHECK-NEXT: [[L0:%.*]] = load <1 x i32>, ptr [[Tmp]], align 4
+// CHECK-NEXT: [[L1:%.*]] = shufflevector <1 x i32> [[L0]], <1 x i32> poison, 
<3 x i32> zeroinitializer
+// CHECK-NEXT: [[TruncV:%.*]] = trunc <3 x i32> [[L1]] to <3 x i1>
+// CHECK-NEXT: [[L2:%.*]] = zext <3 x i1> [[TruncV]] to <3 x i32>
+// CHECK-NEXT: [[L3:%.*]] = load <4 x i32>, ptr [[B]], align 16
+// CHECK-NEXT: [[L4:%.*]] = shufflevector <3 x i32> [[L2]], <3 x i32> poison, 
<4 x i32> <i32 0, i32 1, i32 2, i32 poison>
+// CHECK-NEXT: [[L5:%.*]] = shufflevector <4 x i32> [[L3]], <4 x i32> [[L4]], 
<4 x i32> <i32 4, i32 5, i32 6, i32 3>
+// CHECK-NEXT: store <4 x i32> [[L5]], ptr [[B]], align 16
+void BoolSizeMismatch() {
+  bool4 B = {true,true,true,true};
+  B.xyz = false.xxx;
+}

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to