https://github.com/spall updated https://github.com/llvm/llvm-project/pull/118992
>From e994824f3630ee8b224afceb6c14d980c9013112 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 05:14:17 +0000 Subject: [PATCH 01/13] splat cast wip --- clang/include/clang/AST/OperationKinds.def | 3 ++ clang/include/clang/Sema/SemaHLSL.h | 1 + clang/lib/CodeGen/CGExprAgg.cpp | 42 ++++++++++++++++++++++ clang/lib/CodeGen/CGExprScalar.cpp | 16 +++++++++ clang/lib/Sema/Sema.cpp | 1 + clang/lib/Sema/SemaCast.cpp | 9 ++++- clang/lib/Sema/SemaHLSL.cpp | 26 ++++++++++++++ 7 files changed, 97 insertions(+), 1 deletion(-) diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def index b3dc7c3d8dc77e1..333fc7e1b18821e 100644 --- a/clang/include/clang/AST/OperationKinds.def +++ b/clang/include/clang/AST/OperationKinds.def @@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue) // Aggregate by Value cast (HLSL only). CAST_OPERATION(HLSLElementwiseCast) +// Splat cast for Aggregates (HLSL only). +CAST_OPERATION(HLSLSplatCast) + //===- Binary Operations -------------------------------------------------===// // Operators listed in order of precedence. 
// Note that additions to this should also update the StmtVisitor class, diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 6e8ca2e4710dec8..7508b149b0d81d0 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase { bool CanPerformScalarCast(QualType SrcTy, QualType DestTy); bool ContainsBitField(QualType BaseTy); bool CanPerformElementwiseCast(Expr *Src, QualType DestType); + bool CanPerformSplat(Expr *Src, QualType DestType); ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg); QualType getInoutParameterType(QualType Ty); diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index c3f1cbed6b39f95..f26189bc4907cea 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) { return false; } +static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, + QualType DestTy, llvm::Value *SrcVal, + QualType SrcTy, SourceLocation Loc) { + // Flatten our destination + SmallVector<QualType> DestTypes; // Flattened type + SmallVector<llvm::Value *, 4> IdxList; + SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList; + // ^^ Flattened accesses to DestVal we want to store into + CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, + DestTypes); + + if (const VectorType *VT = SrcTy->getAs<VectorType>()) { + assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); + + SrcTy = VT->getElementType(); + SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, + "vec.load"); + } + assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); + for(unsigned i = 0; i < StoreGEPList.size(); i ++) { + llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy, + DestTypes[i], + Loc); + CGF.PerformStore(StoreGEPList[i], Cast); + } +} + // emit a flat cast where the RHS is a scalar, including vector static 
void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal, QualType DestTy, llvm::Value *SrcVal, @@ -963,6 +990,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) { case CK_HLSLArrayRValue: Visit(E->getSubExpr()); break; + case CK_HLSLSplatCast: { + Expr *Src = E->getSubExpr(); + QualType SrcTy = Src->getType(); + RValue RV = CGF.EmitAnyExpr(Src); + QualType DestTy = E->getType(); + Address DestVal = Dest.getAddress(); + SourceLocation Loc = E->getExprLoc(); + + if (RV.isScalar()) { + llvm::Value *SrcVal = RV.getScalarVal(); + EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); + break; + } + llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector."); + } case CK_HLSLElementwiseCast: { Expr *Src = E->getSubExpr(); QualType SrcTy = Src->getType(); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 80daed7e5395193..7dc2682bae42f2e 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy); return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc"); } + case CK_HLSLSplatCast: { + assert(DestTy->isVectorType() && "Destination type must be a vector."); + auto *DestVecTy = DestTy->getAs<VectorType>(); + QualType SrcTy = E->getType(); + SourceLocation Loc = CE->getExprLoc(); + Value *V = Visit(const_cast<Expr *>(E)); + if (auto *VecTy = SrcTy->getAs<VectorType>()) { + assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); + V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); + SrcTy = VecTy->getElementType(); + } + assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); + Value *Cast = EmitScalarConversion(V, SrcTy, + DestVecTy->getElementType(), Loc); + return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat"); + } case CK_HLSLElementwiseCast: { RValue RV = CGF.EmitAnyExpr(E); 
SourceLocation Loc = CE->getExprLoc(); diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp index 15c18f9a4525b22..9eeefbb3c002329 100644 --- a/clang/lib/Sema/Sema.cpp +++ b/clang/lib/Sema/Sema.cpp @@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty, case CK_ToVoid: case CK_NonAtomicToAtomic: case CK_HLSLArrayRValue: + case CK_HLSLSplatCast: break; } } diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp index 23be71ad8e2aebc..56d8396b1e9d41a 100644 --- a/clang/lib/Sema/SemaCast.cpp +++ b/clang/lib/Sema/SemaCast.cpp @@ -2776,9 +2776,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle, CheckedConversionKind CCK = FunctionalStyle ? CheckedConversionKind::FunctionalCast : CheckedConversionKind::CStyleCast; + // This case should not trigger on regular vector splat - // vector cast, vector truncation, or special hlsl splat cases QualType SrcTy = SrcExpr.get()->getType(); + if (Self.getLangOpts().HLSL && + Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) { + Kind = CK_HLSLSplatCast; + return; + } + + // This case should not trigger on regular vector cast, vector truncation if (Self.getLangOpts().HLSL && Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) { if (SrcTy->isConstantArrayType()) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index ec6b5b45de42bfa..7c9365787fd4fb5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2804,6 +2804,32 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) { return false; } +// Can perform an HLSL splat cast if the Dest is an aggregate and the +// Src is a scalar or a vector of length 1 +// Or if Dest is a vector and Src is a vector of length 1 +bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { + + QualType SrcTy = Src->getType(); + if (SrcTy->isScalarType() && DestTy->isVectorType()) + return false; + + const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); + if (!(SrcTy->isScalarType() || 
(SrcVecTy && SrcVecTy->getNumElements() == 1))) + return false; + + if (SrcVecTy) + SrcTy = SrcVecTy->getElementType(); + + llvm::SmallVector<QualType> DestTypes; + BuildFlattenedTypeList(DestTy, DestTypes); + + for(unsigned i = 0; i < DestTypes.size(); i ++) { + if (!CanPerformScalarCast(SrcTy, DestTypes[i])) + return false; + } + return true; +} + // Can we perform an HLSL Elementwise cast? // TODO: update this code when matrices are added; see issue #88060 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) { >From 24bea86dd7a2c39ca9f21480990236dc44df8cf3 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 05:19:00 +0000 Subject: [PATCH 02/13] make clang format happy --- clang/lib/CodeGen/CGExprAgg.cpp | 19 ++++++++----------- clang/lib/CodeGen/CGExprScalar.cpp | 7 ++++--- clang/lib/Sema/SemaHLSL.cpp | 2 +- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index f26189bc4907cea..60beabf3a5fd0aa 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -492,28 +492,25 @@ static bool isTrivialFiller(Expr *E) { } static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, - QualType DestTy, llvm::Value *SrcVal, - QualType SrcTy, SourceLocation Loc) { + QualType DestTy, llvm::Value *SrcVal, + QualType SrcTy, SourceLocation Loc) { // Flatten our destination SmallVector<QualType> DestTypes; // Flattened type SmallVector<llvm::Value *, 4> IdxList; SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList; // ^^ Flattened accesses to DestVal we want to store into - CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, - DestTypes); + CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes); if (const VectorType *VT = SrcTy->getAs<VectorType>()) { assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); SrcTy = VT->getElementType(); - SrcVal = 
CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, - "vec.load"); + SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load"); } assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); - for(unsigned i = 0; i < StoreGEPList.size(); i ++) { - llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy, - DestTypes[i], - Loc); + for (unsigned i = 0; i < StoreGEPList.size(); i++) { + llvm::Value *Cast = + CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc); CGF.PerformStore(StoreGEPList[i], Cast); } } @@ -997,7 +994,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) { QualType DestTy = E->getType(); Address DestVal = Dest.getAddress(); SourceLocation Loc = E->getExprLoc(); - + if (RV.isScalar()) { llvm::Value *SrcVal = RV.getScalarVal(); EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 7dc2682bae42f2e..4a20b693b101fae 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2807,9 +2807,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { SrcTy = VecTy->getElementType(); } assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); - Value *Cast = EmitScalarConversion(V, SrcTy, - DestVecTy->getElementType(), Loc); - return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat"); + Value *Cast = + EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc); + return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, + "splat"); } case CK_HLSLElementwiseCast: { RValue RV = CGF.EmitAnyExpr(E); diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 7c9365787fd4fb5..024f778f8ffef5b 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2823,7 +2823,7 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); - for(unsigned i = 0; i < 
DestTypes.size(); i ++) { + for (unsigned i = 0; i < DestTypes.size(); i++) { if (!CanPerformScalarCast(SrcTy, DestTypes[i])) return false; } >From 3575617d436f04eac4faadc17ead8bfe561e7e7c Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 05:59:12 +0000 Subject: [PATCH 03/13] codegen test --- .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl new file mode 100644 index 000000000000000..05359c1bce0ba35 --- /dev/null +++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl @@ -0,0 +1,87 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s + +// array splat +// CHECK-LABEL: define void {{.*}}call4 +// CHECK: [[B:%.*]] = alloca [2 x i32], align 4 +// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: store i32 3, ptr [[G1]], align 4 +// CHECK-NEXT: store i32 3, ptr [[G2]], align 4 +export void call4() { + int B[2] = {1,2}; + B = (int[2])3; +} + +// splat from vector of length 1 +// CHECK-LABEL: define void {{.*}}call8 +// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4 +// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4 +// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 +// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) +// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr 
inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 +// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 +// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4 +export void call8() { + int1 A = {1}; + int B[2] = {1,2}; + B = (int[2])A; +} + +// vector splat from vector of length 1 +// CHECK-LABEL: define void {{.*}}call1 +// CHECK: [[B:%.*]] = alloca <1 x float>, align 4 +// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16 +// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4 +// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0 +// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32 +// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0 +// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16 +export void call1() { + float1 B = {1.0}; + int4 A = (int4)B; +} + +struct S { + int X; + float Y; +}; + +// struct splats? 
+// CHECK-LABEL: define void {{.*}}call3 +// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4 +// CHECK: [[s:%.*]] = alloca %struct.S, align 4 +// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 +// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 +// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 +// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float +// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4 +export void call3() { + int1 A = {1}; + S s = (S)A; +} + +// struct splat from vector of length 1 +// CHECK-LABEL: define void {{.*}}call5 +// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4 +// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4 +// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 +// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 +// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 +// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float +// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4 +export void call5() { + int1 A = {1}; + S s = (S)A; +} >From 288b8dac1c6fa4429c92c566a69da593c2ebb97c Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 17:38:58 +0000 Subject: [PATCH 04/13] Try to handle Cast in all the places it needs to be handled --- clang/lib/AST/Expr.cpp | 1 + clang/lib/AST/ExprConstant.cpp | 2 ++ clang/lib/CodeGen/CGExprAgg.cpp | 1 + clang/lib/CodeGen/CGExprComplex.cpp | 1 + clang/lib/CodeGen/CGExprConstant.cpp | 1 + clang/lib/Edit/RewriteObjCFoundationAPI.cpp | 1 + 
clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp | 1 + 7 files changed, 8 insertions(+) diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index c22aa66ba2cfb3d..bbb475fbb30f269 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -1957,6 +1957,7 @@ bool CastExpr::CastConsistency() const { case CK_HLSLArrayRValue: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: CheckNoBasePath: assert(path_empty() && "Cast kind should not have a base path!"); break; diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 192b679b4c99596..ddc2d008839007e 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -15029,6 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) { case CK_FixedPointCast: case CK_IntegralToFixedPoint: case CK_MatrixCast: + case CK_HLSLSplatCast: llvm_unreachable("invalid cast kind for integral value"); case CK_BitCast: @@ -15907,6 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) { case CK_MatrixCast: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: llvm_unreachable("invalid cast kind for complex value"); case CK_LValueToRValue: diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index 60beabf3a5fd0aa..3584280e2fb9e44 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -1592,6 +1592,7 @@ static bool castPreservesZero(const CastExpr *CE) { case CK_AtomicToNonAtomic: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + // TODO is this true for CK_HLSLSplatCast return true; case CK_BaseToDerivedMemberPointer: diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp index c2679ea92dc9728..3832b9b598b24e9 100644 --- a/clang/lib/CodeGen/CGExprComplex.cpp +++ b/clang/lib/CodeGen/CGExprComplex.cpp @@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op, case 
CK_HLSLVectorTruncation: case CK_HLSLArrayRValue: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: llvm_unreachable("invalid cast kind for complex value"); case CK_FloatingRealToComplex: diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp index ef11798869d3b13..b8ce83803b65fde 100644 --- a/clang/lib/CodeGen/CGExprConstant.cpp +++ b/clang/lib/CodeGen/CGExprConstant.cpp @@ -1336,6 +1336,7 @@ class ConstExprEmitter case CK_HLSLVectorTruncation: case CK_HLSLArrayRValue: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: return nullptr; } llvm_unreachable("Invalid CastKind"); diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp index 32f5ebb55155ed1..10d3f62fcd0a416 100644 --- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp +++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp @@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg, case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: llvm_unreachable("HLSL-specific cast in Objective-C?"); break; diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp index 3a983421358c7f4..d75583f68eb6b7b 100644 --- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp @@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex, case CK_MatrixCast: case CK_VectorSplat: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: case CK_HLSLVectorTruncation: { QualType resultType = CastE->getType(); if (CastE->isGLValue()) >From 0650840642960d950d64e234e9641e34096a6c55 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Wed, 11 Dec 2024 20:54:39 +0000 Subject: [PATCH 05/13] get code compiling after rebase --- clang/lib/CodeGen/CGExprAgg.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git 
a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index 3584280e2fb9e44..3330cd03628f75e 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -496,10 +496,9 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, QualType SrcTy, SourceLocation Loc) { // Flatten our destination SmallVector<QualType> DestTypes; // Flattened type - SmallVector<llvm::Value *, 4> IdxList; SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList; // ^^ Flattened accesses to DestVal we want to store into - CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes); + CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes); if (const VectorType *VT = SrcTy->getAs<VectorType>()) { assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); @@ -511,7 +510,15 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, for (unsigned i = 0; i < StoreGEPList.size(); i++) { llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc); - CGF.PerformStore(StoreGEPList[i], Cast); + + // store back + llvm::Value *Idx = StoreGEPList[i].second; + if (Idx) { + llvm::Value *V = + CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert"); + Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx); + } + CGF.Builder.CreateStore(Cast, StoreGEPList[i].first); } } >From f924b13ada0c3344f3cc4f87a859f0ecd16705cb Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Thu, 12 Dec 2024 00:04:29 +0000 Subject: [PATCH 06/13] Self review --- clang/lib/CodeGen/CGExprScalar.cpp | 15 +++++++----- clang/lib/Sema/SemaHLSL.cpp | 7 +++--- clang/test/SemaHLSL/Language/SplatCasts.hlsl | 25 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 9 deletions(-) create mode 100644 clang/test/SemaHLSL/Language/SplatCasts.hlsl diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 4a20b693b101fae..85c0265ea14b611 100644 --- 
a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2796,17 +2796,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc"); } case CK_HLSLSplatCast: { + // This code should only handle splatting from vectors of length 1. assert(DestTy->isVectorType() && "Destination type must be a vector."); auto *DestVecTy = DestTy->getAs<VectorType>(); QualType SrcTy = E->getType(); SourceLocation Loc = CE->getExprLoc(); Value *V = Visit(const_cast<Expr *>(E)); - if (auto *VecTy = SrcTy->getAs<VectorType>()) { - assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); - V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); - SrcTy = VecTy->getElementType(); - } - assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); + assert(SrcTy->isVectorType() && "Invalid HLSL splat cast."); + + auto *VecTy = SrcTy->getAs<VectorType>(); + assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); + + V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); + SrcTy = VecTy->getElementType(); + Value *Cast = EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc); return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 024f778f8ffef5b..432a42016789ec2 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2814,12 +2814,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { return false; const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); - if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1))) - return false; - if (SrcVecTy) SrcTy = SrcVecTy->getElementType(); + // Src isn't a scalar or a vector of length 1 + if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1)) + return false; + llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); diff 
--git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl new file mode 100644 index 000000000000000..593a8e67fd4a3b8 --- /dev/null +++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s + +// splat from vec1 to vec +// CHECK-LABEL: call1 +// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast> +// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>' +export void call1() { + float1 A = {1.0}; + int3 B = (int3)A; +} + +struct S { + int A; + float B; + int C; + float D; +}; + +// splat from scalar to aggregate +// CHECK-LABEL: call2 +// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast> +// CHECK-NEXt: IntegerLiteral {{.*}} 'int' 5 +export void call2() { + S s = (S)5; +} \ No newline at end of file >From 89ceeb7d6b445f10fa6b7deb8c10267cd292da7b Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Thu, 12 Dec 2024 05:59:55 +0000 Subject: [PATCH 07/13] move code back that broke tests --- clang/lib/Sema/SemaHLSL.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 432a42016789ec2..76ca24b10c60a16 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2814,13 +2814,14 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { return false; const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); - if (SrcVecTy) - SrcTy = SrcVecTy->getElementType(); // Src isn't a scalar or a vector of length 1 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1)) return false; + if (SrcVecTy) + SrcTy = SrcVecTy->getElementType(); + llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); >From 7f5b3e4f39f2a4cf2d42e5281e70d900878c1a3b Mon Sep 17 00:00:00 
2001 From: Sarah Spall <sp...@planetbauer.com> Date: Thu, 12 Dec 2024 06:08:46 +0000 Subject: [PATCH 08/13] fix tests --- .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl index 05359c1bce0ba35..2de68479179dd4c 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl @@ -4,8 +4,8 @@ // CHECK-LABEL: define void {{.*}}call4 // CHECK: [[B:%.*]] = alloca [2 x i32], align 4 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1 // CHECK-NEXT: store i32 3, ptr [[G1]], align 4 // CHECK-NEXT: store i32 3, ptr [[G2]], align 4 export void call4() { @@ -20,8 +20,8 @@ export void call4() { // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4 @@ -58,8 +58,8 @@ struct S { // CHECK: 
[[s:%.*]] = alloca %struct.S, align 4 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float @@ -75,8 +75,8 @@ export void call3() { // CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float >From 844ba82eb5dcfbd0105db2d4943266fa8d009c17 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sarahsp...@microsoft.com> Date: Sat, 8 Feb 2025 09:07:05 -0800 Subject: [PATCH 09/13] add cast to cases --- clang/lib/CodeGen/CGExpr.cpp | 1 + clang/lib/CodeGen/CGExprAgg.cpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 2bbc0791c65876f..545d8b11a6a47a9 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) { case 
CK_HLSLVectorTruncation: case CK_HLSLArrayRValue: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: return EmitUnsupportedLValue(E, "unexpected cast lvalue"); case CK_Dependent: diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index 3330cd03628f75e..b7fe62687b074a0 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -1599,7 +1599,7 @@ static bool castPreservesZero(const CastExpr *CE) { case CK_AtomicToNonAtomic: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: - // TODO is this true for CK_HLSLSplatCast + case CK_HLSLSplatCast: return true; case CK_BaseToDerivedMemberPointer: >From 848315d47512a65ac98ecbb2c2102c6c4eef75f8 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sarahsp...@microsoft.com> Date: Mon, 10 Feb 2025 12:21:59 -0800 Subject: [PATCH 10/13] self review --- clang/include/clang/Sema/SemaHLSL.h | 2 +- clang/lib/CodeGen/CGExprAgg.cpp | 26 +++++++------------- clang/lib/CodeGen/CGExprScalar.cpp | 11 +++------ clang/lib/Sema/SemaCast.cpp | 7 +++++- clang/lib/Sema/SemaHLSL.cpp | 2 +- clang/test/SemaHLSL/Language/SplatCasts.hlsl | 2 +- 6 files changed, 21 insertions(+), 29 deletions(-) diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 7508b149b0d81d0..3772301afdd4fe8 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -144,7 +144,7 @@ class SemaHLSL : public SemaBase { bool CanPerformScalarCast(QualType SrcTy, QualType DestTy); bool ContainsBitField(QualType BaseTy); bool CanPerformElementwiseCast(Expr *Src, QualType DestType); - bool CanPerformSplat(Expr *Src, QualType DestType); + bool CanPerformSplatCast(Expr *Src, QualType DestType); ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg); QualType getInoutParameterType(QualType Ty); diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index b7fe62687b074a0..36557ccb15f1a71 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp 
+++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -500,25 +500,19 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, // ^^ Flattened accesses to DestVal we want to store into CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes); - if (const VectorType *VT = SrcTy->getAs<VectorType>()) { - assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); - - SrcTy = VT->getElementType(); - SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load"); - } assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); - for (unsigned i = 0; i < StoreGEPList.size(); i++) { + for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) { llvm::Value *Cast = - CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc); + CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc); // store back - llvm::Value *Idx = StoreGEPList[i].second; + llvm::Value *Idx = StoreGEPList[I].second; if (Idx) { llvm::Value *V = - CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert"); + CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert"); Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx); } - CGF.Builder.CreateStore(Cast, StoreGEPList[i].first); + CGF.Builder.CreateStore(Cast, StoreGEPList[I].first); } } @@ -1002,12 +996,10 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) { Address DestVal = Dest.getAddress(); SourceLocation Loc = E->getExprLoc(); - if (RV.isScalar()) { - llvm::Value *SrcVal = RV.getScalarVal(); - EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); - break; - } - llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector."); + assert (RV.isScalar() && "RHS of HLSL splat cast must be a scalar."); + llvm::Value *SrcVal = RV.getScalarVal(); + EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); + break; } case CK_HLSLElementwiseCast: { Expr *Src = E->getSubExpr(); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 
85c0265ea14b611..b09c18f4a1229a2 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2796,19 +2796,14 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc"); } case CK_HLSLSplatCast: { - // This code should only handle splatting from vectors of length 1. + // This cast should only handle splatting from vectors of length 1. + // But in Sema a cast should have been inserted to convert the vec1 to a scalar assert(DestTy->isVectorType() && "Destination type must be a vector."); auto *DestVecTy = DestTy->getAs<VectorType>(); QualType SrcTy = E->getType(); SourceLocation Loc = CE->getExprLoc(); Value *V = Visit(const_cast<Expr *>(E)); - assert(SrcTy->isVectorType() && "Invalid HLSL splat cast."); - - auto *VecTy = SrcTy->getAs<VectorType>(); - assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); - - V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); - SrcTy = VecTy->getElementType(); + assert(SrcTy->isBuiltinType() && "Invalid HLSL splat cast."); Value *Cast = EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc); diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp index 56d8396b1e9d41a..a60bc0687461fa6 100644 --- a/clang/lib/Sema/SemaCast.cpp +++ b/clang/lib/Sema/SemaCast.cpp @@ -2780,7 +2780,12 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle, // This case should not trigger on regular vector splat QualType SrcTy = SrcExpr.get()->getType(); if (Self.getLangOpts().HLSL && - Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) { + Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) { + const VectorType *VT = SrcTy->getAs<VectorType>(); + // change splat from vec1 case to splat from scalar + if (VT && VT->getNumElements() == 1) + SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(), + CK_HLSLVectorTruncation, VK_PRValue, nullptr, CCK); Kind = CK_HLSLSplatCast; return; } 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 76ca24b10c60a16..d20bd281b7dc9b5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2807,7 +2807,7 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) { // Can perform an HLSL splat cast if the Dest is an aggregate and the // Src is a scalar or a vector of length 1 // Or if Dest is a vector and Src is a vector of length 1 -bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { +bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) { QualType SrcTy = Src->getType(); if (SrcTy->isScalarType() && DestTy->isVectorType()) diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl index 593a8e67fd4a3b8..cfe3b981dc92cc9 100644 --- a/clang/test/SemaHLSL/Language/SplatCasts.hlsl +++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl @@ -22,4 +22,4 @@ struct S { // CHECK-NEXt: IntegerLiteral {{.*}} 'int' 5 export void call2() { S s = (S)5; -} \ No newline at end of file +} >From cadd309a7d7c61a592d11cb306d44139df8d15ee Mon Sep 17 00:00:00 2001 From: Sarah Spall <sarahsp...@microsoft.com> Date: Mon, 10 Feb 2025 15:26:45 -0800 Subject: [PATCH 11/13] disallow splatting things with bitfields, add tests to show casting bitfields not allowed. Add test showing splatting union is not allowed. Add test showing splatting union in elementwise cast is not allowed. 
--- clang/lib/Sema/SemaHLSL.cpp | 11 +++++-- .../Language/ElementwiseCast-errors.hlsl | 20 ++++++++++++ .../SemaHLSL/Language/SplatCasts-errors.hlsl | 32 +++++++++++++++++++ 3 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index d20bd281b7dc9b5..252121f88af0507 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2771,7 +2771,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) { } // Detect if a type contains a bitfield. Will be removed when -// bitfield support is added to HLSLElementwiseCast +// bitfield support is added to HLSLElementwiseCast and HLSLSplatCast bool SemaHLSL::ContainsBitField(QualType BaseTy) { llvm::SmallVector<QualType, 16> WorkList; WorkList.push_back(BaseTy); @@ -2822,11 +2822,16 @@ bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) { if (SrcVecTy) SrcTy = SrcVecTy->getElementType(); + if (ContainsBitField(DestTy)) + return false; + llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); - for (unsigned i = 0; i < DestTypes.size(); i++) { - if (!CanPerformScalarCast(SrcTy, DestTypes[i])) + for (unsigned I = 0, Size = DestTypes.size(); I < Size; I++) { + if (DestTypes[I]->isUnionType()) + return false; + if (!CanPerformScalarCast(SrcTy, DestTypes[I])) return false; } return true; diff --git a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl index c900c83a063a06b..b7085bc69547b52 100644 --- a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl +++ b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl @@ -27,3 +27,23 @@ export void cantCast3() { S s = (S)C; // expected-error@-1 {{no matching conversion for C-style cast from 'int2' (aka 'vector<int, 2>') to 'S'}} } + +struct R { +// expected-note@-1 {{candidate constructor (the implicit copy 
constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'const R' for 1st argument}} +// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'R' for 1st argument}} +// expected-note@-3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}} + int A; + union { + float F; + int4 G; + }; +}; + +export void cantCast4() { + int2 A = {1,2}; + R r = R(A); + // expected-error@-1 {{no matching conversion for functional-style cast from 'int2' (aka 'vector<int, 2>') to 'R'}} + R r2 = {1, 2}; + int2 B = (int2)r2; + // expected-error@-1 {{cannot convert 'R' to 'int2' (aka 'vector<int, 2>') without a conversion operator}} +} diff --git a/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl b/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl new file mode 100644 index 000000000000000..b0234c597eadf08 --- /dev/null +++ b/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl @@ -0,0 +1,32 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -verify + +struct S { +// expected-note@-1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const S' for 1st argument}} +// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'S' for 1st argument}} +// expected-note@-3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}} + int A : 8; + int B; +}; + +struct R { +// expected-note@-1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const R' for 1st argument}} +// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'R' for 1st argument}} +// expected-note@-3 {{candidate constructor (the 
implicit default constructor) not viable: requires 0 arguments, but 1 was provided}} + int A; + union { + float F; + int4 G; + }; +}; + +// casting types which contain bitfields is not yet supported. +export void cantCast() { + S s = (S)1; + // expected-error@-1 {{no matching conversion for C-style cast from 'int' to 'S'}} +} + +// Can't cast a union +export void cantCast2() { + R r = (R)1; + // expected-error@-1 {{no matching conversion for C-style cast from 'int' to 'R'}} +} >From 93c0450a22ad82a939f322e40d4c4e5a6b9d56dd Mon Sep 17 00:00:00 2001 From: Sarah Spall <sarahsp...@microsoft.com> Date: Mon, 10 Feb 2025 19:12:31 -0800 Subject: [PATCH 12/13] fix tests + associated issues in code --- clang/lib/Sema/SemaCast.cpp | 27 ++++++++++--------- clang/lib/Sema/SemaHLSL.cpp | 4 ++- .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 8 +++--- clang/test/SemaHLSL/Language/SplatCasts.hlsl | 1 + 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp index a60bc0687461fa6..2f5deba7e12583b 100644 --- a/clang/lib/Sema/SemaCast.cpp +++ b/clang/lib/Sema/SemaCast.cpp @@ -2777,19 +2777,7 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle, ? 
CheckedConversionKind::FunctionalCast : CheckedConversionKind::CStyleCast; - // This case should not trigger on regular vector splat QualType SrcTy = SrcExpr.get()->getType(); - if (Self.getLangOpts().HLSL && - Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) { - const VectorType *VT = SrcTy->getAs<VectorType>(); - // change splat from vec1 case to splat from scalar - if (VT && VT->getNumElements() == 1) - SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(), - CK_HLSLVectorTruncation, VK_PRValue, nullptr, CCK); - Kind = CK_HLSLSplatCast; - return; - } - // This case should not trigger on regular vector cast, vector truncation if (Self.getLangOpts().HLSL && Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) { @@ -2801,6 +2789,21 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle, return; } + // This case should not trigger on regular vector splat + // If the relative order of this and the HLSLElementWise cast checks + // are changed, it might change which cast handles what in a few cases + if (Self.getLangOpts().HLSL && + Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) { + const VectorType *VT = SrcTy->getAs<VectorType>(); + // change splat from vec1 case to splat from scalar + if (VT && VT->getNumElements() == 1) + SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(), + CK_HLSLVectorTruncation, + SrcExpr.get()->getValueKind(), nullptr, CCK); + Kind = CK_HLSLSplatCast; + return; + } + if (ValueKind == VK_PRValue && !DestType->isRecordType() && !isPlaceholder(BuiltinType::Overload)) { SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get()); diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 252121f88af0507..9898732f30b2789 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2810,7 +2810,9 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) { bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) { QualType SrcTy = 
Src->getType(); - if (SrcTy->isScalarType() && DestTy->isVectorType()) + // Not a valid HLSL Splat cast if Dest is a scalar or if this is going to + // be a vector splat from a scalar. + if ((SrcTy->isScalarType() && DestTy->isVectorType()) || DestTy->isScalarType()) return false; const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl index 2de68479179dd4c..0bc3e3fbd86cc33 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl @@ -20,9 +20,9 @@ export void call4() { // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1 -// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4 export void call8() { @@ -37,7 +37,7 @@ export void call8() { // CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16 // CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4 -// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0 // CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32 // CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0 // CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer @@ -58,9 +58,9 @@ struct S { // CHECK: [[s:%.*]] = alloca 
%struct.S, align 4 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1 -// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float // CHECK-NEXT: store float [[C]], ptr [[G2]], align 4 @@ -75,9 +75,9 @@ export void call3() { // CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1 -// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float // CHECK-NEXT: store float [[C]], ptr [[G2]], align 4 diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl index cfe3b981dc92cc9..c57a577e8929f84 100644 --- a/clang/test/SemaHLSL/Language/SplatCasts.hlsl +++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl @@ -3,6 +3,7 @@ // splat from vec1 to vec // CHECK-LABEL: call1 // CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast> +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast // CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>' export void call1() { float1 A = {1.0}; >From 1d317f21a280aa89b39cfe578fee3f022044e432 Mon Sep 17 
00:00:00 2001 From: Sarah Spall <sarahsp...@microsoft.com> Date: Mon, 10 Feb 2025 19:16:00 -0800 Subject: [PATCH 13/13] clang format --- clang/lib/CodeGen/CGExprAgg.cpp | 2 +- clang/lib/CodeGen/CGExprScalar.cpp | 3 ++- clang/lib/Sema/SemaCast.cpp | 6 +++--- clang/lib/Sema/SemaHLSL.cpp | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index 36557ccb15f1a71..69e77667648d0ca 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -996,7 +996,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) { Address DestVal = Dest.getAddress(); SourceLocation Loc = E->getExprLoc(); - assert (RV.isScalar() && "RHS of HLSL splat cast must be a scalar."); + assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar."); llvm::Value *SrcVal = RV.getScalarVal(); EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); break; diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index b09c18f4a1229a2..cd7d9c243fcb24c 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2797,7 +2797,8 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { } case CK_HLSLSplatCast: { // This cast should only handle splatting from vectors of length 1. 
- // But in Sema a cast should have been inserted to convert the vec1 to a scalar + // But in Sema a cast should have been inserted to convert the vec1 to a + // scalar assert(DestTy->isVectorType() && "Destination type must be a vector."); auto *DestVecTy = DestTy->getAs<VectorType>(); QualType SrcTy = E->getType(); diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp index 2f5deba7e12583b..e733dbc1c0d6274 100644 --- a/clang/lib/Sema/SemaCast.cpp +++ b/clang/lib/Sema/SemaCast.cpp @@ -2797,9 +2797,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle, const VectorType *VT = SrcTy->getAs<VectorType>(); // change splat from vec1 case to splat from scalar if (VT && VT->getNumElements() == 1) - SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(), - CK_HLSLVectorTruncation, - SrcExpr.get()->getValueKind(), nullptr, CCK); + SrcExpr = Self.ImpCastExprToType( + SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation, + SrcExpr.get()->getValueKind(), nullptr, CCK); Kind = CK_HLSLSplatCast; return; } diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 9898732f30b2789..b68a589e6de81ab 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2812,7 +2812,8 @@ bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) { QualType SrcTy = Src->getType(); // Not a valid HLSL Splat cast if Dest is a scalar or if this is going to // be a vector splat from a scalar. - if ((SrcTy->isScalarType() && DestTy->isVectorType()) || DestTy->isScalarType()) + if ((SrcTy->isScalarType() && DestTy->isVectorType()) || + DestTy->isScalarType()) return false; const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits