https://github.com/spall updated https://github.com/llvm/llvm-project/pull/118992
>From e994824f3630ee8b224afceb6c14d980c9013112 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 05:14:17 +0000 Subject: [PATCH 1/9] splat cast wip --- clang/include/clang/AST/OperationKinds.def | 3 ++ clang/include/clang/Sema/SemaHLSL.h | 1 + clang/lib/CodeGen/CGExprAgg.cpp | 42 ++++++++++++++++++++++ clang/lib/CodeGen/CGExprScalar.cpp | 16 +++++++++ clang/lib/Sema/Sema.cpp | 1 + clang/lib/Sema/SemaCast.cpp | 9 ++++- clang/lib/Sema/SemaHLSL.cpp | 26 ++++++++++++++ 7 files changed, 97 insertions(+), 1 deletion(-) diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def index b3dc7c3d8dc77e1..333fc7e1b18821e 100644 --- a/clang/include/clang/AST/OperationKinds.def +++ b/clang/include/clang/AST/OperationKinds.def @@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue) // Aggregate by Value cast (HLSL only). CAST_OPERATION(HLSLElementwiseCast) +// Splat cast for Aggregates (HLSL only). +CAST_OPERATION(HLSLSplatCast) + //===- Binary Operations -------------------------------------------------===// // Operators listed in order of precedence. 
// Note that additions to this should also update the StmtVisitor class, diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 6e8ca2e4710dec8..7508b149b0d81d0 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase { bool CanPerformScalarCast(QualType SrcTy, QualType DestTy); bool ContainsBitField(QualType BaseTy); bool CanPerformElementwiseCast(Expr *Src, QualType DestType); + bool CanPerformSplat(Expr *Src, QualType DestType); ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg); QualType getInoutParameterType(QualType Ty); diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index c3f1cbed6b39f95..f26189bc4907cea 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) { return false; } +static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, + QualType DestTy, llvm::Value *SrcVal, + QualType SrcTy, SourceLocation Loc) { + // Flatten our destination + SmallVector<QualType> DestTypes; // Flattened type + SmallVector<llvm::Value *, 4> IdxList; + SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList; + // ^^ Flattened accesses to DestVal we want to store into + CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, + DestTypes); + + if (const VectorType *VT = SrcTy->getAs<VectorType>()) { + assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); + + SrcTy = VT->getElementType(); + SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, + "vec.load"); + } + assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); + for(unsigned i = 0; i < StoreGEPList.size(); i ++) { + llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy, + DestTypes[i], + Loc); + CGF.PerformStore(StoreGEPList[i], Cast); + } +} + // emit a flat cast where the RHS is a scalar, including vector static 
void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal, QualType DestTy, llvm::Value *SrcVal, @@ -963,6 +990,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) { case CK_HLSLArrayRValue: Visit(E->getSubExpr()); break; + case CK_HLSLSplatCast: { + Expr *Src = E->getSubExpr(); + QualType SrcTy = Src->getType(); + RValue RV = CGF.EmitAnyExpr(Src); + QualType DestTy = E->getType(); + Address DestVal = Dest.getAddress(); + SourceLocation Loc = E->getExprLoc(); + + if (RV.isScalar()) { + llvm::Value *SrcVal = RV.getScalarVal(); + EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); + break; + } + llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector."); + } case CK_HLSLElementwiseCast: { Expr *Src = E->getSubExpr(); QualType SrcTy = Src->getType(); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 80daed7e5395193..7dc2682bae42f2e 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy); return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc"); } + case CK_HLSLSplatCast: { + assert(DestTy->isVectorType() && "Destination type must be a vector."); + auto *DestVecTy = DestTy->getAs<VectorType>(); + QualType SrcTy = E->getType(); + SourceLocation Loc = CE->getExprLoc(); + Value *V = Visit(const_cast<Expr *>(E)); + if (auto *VecTy = SrcTy->getAs<VectorType>()) { + assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); + V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); + SrcTy = VecTy->getElementType(); + } + assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); + Value *Cast = EmitScalarConversion(V, SrcTy, + DestVecTy->getElementType(), Loc); + return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat"); + } case CK_HLSLElementwiseCast: { RValue RV = CGF.EmitAnyExpr(E); 
SourceLocation Loc = CE->getExprLoc(); diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp index 15c18f9a4525b22..9eeefbb3c002329 100644 --- a/clang/lib/Sema/Sema.cpp +++ b/clang/lib/Sema/Sema.cpp @@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty, case CK_ToVoid: case CK_NonAtomicToAtomic: case CK_HLSLArrayRValue: + case CK_HLSLSplatCast: break; } } diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp index 23be71ad8e2aebc..56d8396b1e9d41a 100644 --- a/clang/lib/Sema/SemaCast.cpp +++ b/clang/lib/Sema/SemaCast.cpp @@ -2776,9 +2776,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle, CheckedConversionKind CCK = FunctionalStyle ? CheckedConversionKind::FunctionalCast : CheckedConversionKind::CStyleCast; + // This case should not trigger on regular vector splat - // vector cast, vector truncation, or special hlsl splat cases QualType SrcTy = SrcExpr.get()->getType(); + if (Self.getLangOpts().HLSL && + Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) { + Kind = CK_HLSLSplatCast; + return; + } + + // This case should not trigger on regular vector cast, vector truncation if (Self.getLangOpts().HLSL && Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) { if (SrcTy->isConstantArrayType()) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index ec6b5b45de42bfa..7c9365787fd4fb5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2804,6 +2804,32 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) { return false; } +// Can perform an HLSL splat cast if the Dest is an aggregate and the +// Src is a scalar or a vector of length 1 +// Or if Dest is a vector and Src is a vector of length 1 +bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { + + QualType SrcTy = Src->getType(); + if (SrcTy->isScalarType() && DestTy->isVectorType()) + return false; + + const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); + if (!(SrcTy->isScalarType() || 
(SrcVecTy && SrcVecTy->getNumElements() == 1))) + return false; + + if (SrcVecTy) + SrcTy = SrcVecTy->getElementType(); + + llvm::SmallVector<QualType> DestTypes; + BuildFlattenedTypeList(DestTy, DestTypes); + + for(unsigned i = 0; i < DestTypes.size(); i ++) { + if (!CanPerformScalarCast(SrcTy, DestTypes[i])) + return false; + } + return true; +} + // Can we perform an HLSL Elementwise cast? // TODO: update this code when matrices are added; see issue #88060 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) { >From 24bea86dd7a2c39ca9f21480990236dc44df8cf3 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 05:19:00 +0000 Subject: [PATCH 2/9] make clang format happy --- clang/lib/CodeGen/CGExprAgg.cpp | 19 ++++++++----------- clang/lib/CodeGen/CGExprScalar.cpp | 7 ++++--- clang/lib/Sema/SemaHLSL.cpp | 2 +- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index f26189bc4907cea..60beabf3a5fd0aa 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -492,28 +492,25 @@ static bool isTrivialFiller(Expr *E) { } static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, - QualType DestTy, llvm::Value *SrcVal, - QualType SrcTy, SourceLocation Loc) { + QualType DestTy, llvm::Value *SrcVal, + QualType SrcTy, SourceLocation Loc) { // Flatten our destination SmallVector<QualType> DestTypes; // Flattened type SmallVector<llvm::Value *, 4> IdxList; SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList; // ^^ Flattened accesses to DestVal we want to store into - CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, - DestTypes); + CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes); if (const VectorType *VT = SrcTy->getAs<VectorType>()) { assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); SrcTy = VT->getElementType(); - SrcVal = 
CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, - "vec.load"); + SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load"); } assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); - for(unsigned i = 0; i < StoreGEPList.size(); i ++) { - llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy, - DestTypes[i], - Loc); + for (unsigned i = 0; i < StoreGEPList.size(); i++) { + llvm::Value *Cast = + CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc); CGF.PerformStore(StoreGEPList[i], Cast); } } @@ -997,7 +994,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) { QualType DestTy = E->getType(); Address DestVal = Dest.getAddress(); SourceLocation Loc = E->getExprLoc(); - + if (RV.isScalar()) { llvm::Value *SrcVal = RV.getScalarVal(); EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 7dc2682bae42f2e..4a20b693b101fae 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2807,9 +2807,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { SrcTy = VecTy->getElementType(); } assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); - Value *Cast = EmitScalarConversion(V, SrcTy, - DestVecTy->getElementType(), Loc); - return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat"); + Value *Cast = + EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc); + return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, + "splat"); } case CK_HLSLElementwiseCast: { RValue RV = CGF.EmitAnyExpr(E); diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 7c9365787fd4fb5..024f778f8ffef5b 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2823,7 +2823,7 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); - for(unsigned i = 0; i < 
DestTypes.size(); i ++) { + for (unsigned i = 0; i < DestTypes.size(); i++) { if (!CanPerformScalarCast(SrcTy, DestTypes[i])) return false; } >From 3575617d436f04eac4faadc17ead8bfe561e7e7c Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 05:59:12 +0000 Subject: [PATCH 3/9] codegen test --- .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl new file mode 100644 index 000000000000000..05359c1bce0ba35 --- /dev/null +++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl @@ -0,0 +1,87 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s + +// array splat +// CHECK-LABEL: define void {{.*}}call4 +// CHECK: [[B:%.*]] = alloca [2 x i32], align 4 +// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: store i32 3, ptr [[G1]], align 4 +// CHECK-NEXT: store i32 3, ptr [[G2]], align 4 +export void call4() { + int B[2] = {1,2}; + B = (int[2])3; +} + +// splat from vector of length 1 +// CHECK-LABEL: define void {{.*}}call8 +// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4 +// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4 +// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 +// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) +// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds 
[2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 +// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 +// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4 +export void call8() { + int1 A = {1}; + int B[2] = {1,2}; + B = (int[2])A; +} + +// vector splat from vector of length 1 +// CHECK-LABEL: define void {{.*}}call1 +// CHECK: [[B:%.*]] = alloca <1 x float>, align 4 +// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16 +// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4 +// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0 +// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32 +// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0 +// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16 +export void call1() { + float1 B = {1.0}; + int4 A = (int4)B; +} + +struct S { + int X; + float Y; +}; + +// struct splats? 
+// CHECK-LABEL: define void {{.*}}call3 +// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4 +// CHECK: [[s:%.*]] = alloca %struct.S, align 4 +// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 +// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 +// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 +// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float +// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4 +export void call3() { + int1 A = {1}; + S s = (S)A; +} + +// struct splat from vector of length 1 +// CHECK-LABEL: define void {{.*}}call5 +// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4 +// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4 +// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 +// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 +// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 +// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float +// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4 +export void call5() { + int1 A = {1}; + S s = (S)A; +} >From 288b8dac1c6fa4429c92c566a69da593c2ebb97c Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Dec 2024 17:38:58 +0000 Subject: [PATCH 4/9] Try to handle Cast in all the places it needs to be handled --- clang/lib/AST/Expr.cpp | 1 + clang/lib/AST/ExprConstant.cpp | 2 ++ clang/lib/CodeGen/CGExprAgg.cpp | 1 + clang/lib/CodeGen/CGExprComplex.cpp | 1 + clang/lib/CodeGen/CGExprConstant.cpp | 1 + clang/lib/Edit/RewriteObjCFoundationAPI.cpp | 1 + clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp 
| 1 + 7 files changed, 8 insertions(+) diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index c22aa66ba2cfb3d..bbb475fbb30f269 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -1957,6 +1957,7 @@ bool CastExpr::CastConsistency() const { case CK_HLSLArrayRValue: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: CheckNoBasePath: assert(path_empty() && "Cast kind should not have a base path!"); break; diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 192b679b4c99596..ddc2d008839007e 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -15029,6 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) { case CK_FixedPointCast: case CK_IntegralToFixedPoint: case CK_MatrixCast: + case CK_HLSLSplatCast: llvm_unreachable("invalid cast kind for integral value"); case CK_BitCast: @@ -15907,6 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) { case CK_MatrixCast: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: llvm_unreachable("invalid cast kind for complex value"); case CK_LValueToRValue: diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index 60beabf3a5fd0aa..3584280e2fb9e44 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -1592,6 +1592,7 @@ static bool castPreservesZero(const CastExpr *CE) { case CK_AtomicToNonAtomic: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + // TODO is this true for CK_HLSLSplatCast return true; case CK_BaseToDerivedMemberPointer: diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp index c2679ea92dc9728..3832b9b598b24e9 100644 --- a/clang/lib/CodeGen/CGExprComplex.cpp +++ b/clang/lib/CodeGen/CGExprComplex.cpp @@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op, case CK_HLSLVectorTruncation: case 
CK_HLSLArrayRValue: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: llvm_unreachable("invalid cast kind for complex value"); case CK_FloatingRealToComplex: diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp index ef11798869d3b13..b8ce83803b65fde 100644 --- a/clang/lib/CodeGen/CGExprConstant.cpp +++ b/clang/lib/CodeGen/CGExprConstant.cpp @@ -1336,6 +1336,7 @@ class ConstExprEmitter case CK_HLSLVectorTruncation: case CK_HLSLArrayRValue: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: return nullptr; } llvm_unreachable("Invalid CastKind"); diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp index 32f5ebb55155ed1..10d3f62fcd0a416 100644 --- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp +++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp @@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg, case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: llvm_unreachable("HLSL-specific cast in Objective-C?"); break; diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp index 3a983421358c7f4..d75583f68eb6b7b 100644 --- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp @@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex, case CK_MatrixCast: case CK_VectorSplat: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: case CK_HLSLVectorTruncation: { QualType resultType = CastE->getType(); if (CastE->isGLValue()) >From 0650840642960d950d64e234e9641e34096a6c55 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Wed, 11 Dec 2024 20:54:39 +0000 Subject: [PATCH 5/9] get code compiling after rebase --- clang/lib/CodeGen/CGExprAgg.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/clang/lib/CodeGen/CGExprAgg.cpp 
b/clang/lib/CodeGen/CGExprAgg.cpp index 3584280e2fb9e44..3330cd03628f75e 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -496,10 +496,9 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, QualType SrcTy, SourceLocation Loc) { // Flatten our destination SmallVector<QualType> DestTypes; // Flattened type - SmallVector<llvm::Value *, 4> IdxList; SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList; // ^^ Flattened accesses to DestVal we want to store into - CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes); + CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes); if (const VectorType *VT = SrcTy->getAs<VectorType>()) { assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast."); @@ -511,7 +510,15 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal, for (unsigned i = 0; i < StoreGEPList.size(); i++) { llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc); - CGF.PerformStore(StoreGEPList[i], Cast); + + // store back + llvm::Value *Idx = StoreGEPList[i].second; + if (Idx) { + llvm::Value *V = + CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert"); + Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx); + } + CGF.Builder.CreateStore(Cast, StoreGEPList[i].first); } } >From f924b13ada0c3344f3cc4f87a859f0ecd16705cb Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Thu, 12 Dec 2024 00:04:29 +0000 Subject: [PATCH 6/9] Self review --- clang/lib/CodeGen/CGExprScalar.cpp | 15 +++++++----- clang/lib/Sema/SemaHLSL.cpp | 7 +++--- clang/test/SemaHLSL/Language/SplatCasts.hlsl | 25 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 9 deletions(-) create mode 100644 clang/test/SemaHLSL/Language/SplatCasts.hlsl diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 4a20b693b101fae..85c0265ea14b611 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ 
b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2796,17 +2796,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc"); } case CK_HLSLSplatCast: { + // This code should only handle splatting from vectors of length 1. assert(DestTy->isVectorType() && "Destination type must be a vector."); auto *DestVecTy = DestTy->getAs<VectorType>(); QualType SrcTy = E->getType(); SourceLocation Loc = CE->getExprLoc(); Value *V = Visit(const_cast<Expr *>(E)); - if (auto *VecTy = SrcTy->getAs<VectorType>()) { - assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); - V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); - SrcTy = VecTy->getElementType(); - } - assert(SrcTy->isScalarType() && "Invalid HLSL splat cast."); + assert(SrcTy->isVectorType() && "Invalid HLSL splat cast."); + + auto *VecTy = SrcTy->getAs<VectorType>(); + assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast."); + + V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load"); + SrcTy = VecTy->getElementType(); + Value *Cast = EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc); return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 024f778f8ffef5b..432a42016789ec2 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2814,12 +2814,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { return false; const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); - if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1))) - return false; - if (SrcVecTy) SrcTy = SrcVecTy->getElementType(); + // Src isn't a scalar or a vector of length 1 + if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1)) + return false; + llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); diff --git 
a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl new file mode 100644 index 000000000000000..593a8e67fd4a3b8 --- /dev/null +++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s + +// splat from vec1 to vec +// CHECK-LABEL: call1 +// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast> +// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>' +export void call1() { + float1 A = {1.0}; + int3 B = (int3)A; +} + +struct S { + int A; + float B; + int C; + float D; +}; + +// splat from scalar to aggregate +// CHECK-LABEL: call2 +// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast> +// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 5 +export void call2() { + S s = (S)5; +} \ No newline at end of file >From 89ceeb7d6b445f10fa6b7deb8c10267cd292da7b Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Thu, 12 Dec 2024 05:59:55 +0000 Subject: [PATCH 7/9] move code back that broke tests --- clang/lib/Sema/SemaHLSL.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 432a42016789ec2..76ca24b10c60a16 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2814,13 +2814,14 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) { return false; const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); - if (SrcVecTy) - SrcTy = SrcVecTy->getElementType(); // Src isn't a scalar or a vector of length 1 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1)) return false; + if (SrcVecTy) + SrcTy = SrcVecTy->getElementType(); + llvm::SmallVector<QualType> DestTypes; BuildFlattenedTypeList(DestTy, DestTypes); >From 7f5b3e4f39f2a4cf2d42e5281e70d900878c1a3b Mon Sep 17 00:00:00 2001 From:
Sarah Spall <sp...@planetbauer.com> Date: Thu, 12 Dec 2024 06:08:46 +0000 Subject: [PATCH 8/9] fix tests --- .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl index 05359c1bce0ba35..2de68479179dd4c 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl @@ -4,8 +4,8 @@ // CHECK-LABEL: define void {{.*}}call4 // CHECK: [[B:%.*]] = alloca [2 x i32], align 4 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1 // CHECK-NEXT: store i32 3, ptr [[G1]], align 4 // CHECK-NEXT: store i32 3, ptr [[G2]], align 4 export void call4() { @@ -20,8 +20,8 @@ export void call4() { // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false) // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4 @@ -58,8 +58,8 @@ struct S { // CHECK: [[s:%.*]] = alloca 
%struct.S, align 4 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float @@ -75,8 +75,8 @@ export void call3() { // CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0 -// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1 +// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0 +// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float >From 844ba82eb5dcfbd0105db2d4943266fa8d009c17 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sarahsp...@microsoft.com> Date: Sat, 8 Feb 2025 09:07:05 -0800 Subject: [PATCH 9/9] add cast to cases --- clang/lib/CodeGen/CGExpr.cpp | 1 + clang/lib/CodeGen/CGExprAgg.cpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 2bbc0791c65876f..545d8b11a6a47a9 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) { case CK_HLSLVectorTruncation: 
case CK_HLSLArrayRValue: case CK_HLSLElementwiseCast: + case CK_HLSLSplatCast: return EmitUnsupportedLValue(E, "unexpected cast lvalue"); case CK_Dependent: diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp index 3330cd03628f75e..b7fe62687b074a0 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -1599,7 +1599,7 @@ static bool castPreservesZero(const CastExpr *CE) { case CK_AtomicToNonAtomic: case CK_HLSLVectorTruncation: case CK_HLSLElementwiseCast: - // TODO is this true for CK_HLSLSplatCast + case CK_HLSLSplatCast: return true; case CK_BaseToDerivedMemberPointer: _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits