llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-codegen Author: Chris B (llvm-beanz) <details> <summary>Changes</summary> This PR adds scalar/vector overloads for vector conditions to the `select` builtin, and updates the sema checking and codegen to allow scalars to extend to vectors. Fixes #<!-- -->126570 --- Full diff: https://github.com/llvm/llvm-project/pull/129396.diff 7 Files Affected: - (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3) - (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8) - (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+5) - (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+36) - (modified) clang/lib/Sema/SemaHLSL.cpp (+32-24) - (modified) clang/test/CodeGenHLSL/builtins/select.hlsl (+29) - (modified) clang/test/SemaHLSL/BuiltIns/select-errors.hlsl (+22-76) ``````````diff diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index d094c075ecee2..be649f0bce320 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12682,6 +12682,9 @@ def err_hlsl_param_qualifier_mismatch : def err_hlsl_vector_compound_assignment_truncation : Error< "left hand operand of type %0 to compound assignment cannot be truncated " "when used with right hand operand of type %1">; +def err_hlsl_builtin_scalar_vector_mismatch : Error< + "%select{all|second and third}0 arguments to %1 must be of scalar or " + "vector type with matching scalar element type%diff{: $ vs $|}2,3">; def warn_hlsl_impcast_vector_truncation : Warning< "implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>; diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 03b8d16b76e0d..a84e5e4b59c89 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -19741,6 +19741,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { RValFalse.isScalar() ? 
RValFalse.getScalarVal() : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this); + if (auto *VTy = E->getType()->getAs<VectorType>()) { + if (!OpTrue->getType()->isVectorTy()) + OpTrue = + Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat"); + if (!OpFalse->getType()->isVectorTy()) + OpFalse = + Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat"); + } Value *SelectVal = Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select"); diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h index 0d568539cd66a..daccd2d793aa8 100644 --- a/clang/lib/Headers/hlsl/hlsl_detail.h +++ b/clang/lib/Headers/hlsl/hlsl_detail.h @@ -95,6 +95,11 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) { #endif } +template<typename T> +struct is_arithmetic { + static const bool Value = __is_arithmetic(T); +}; + } // namespace __detail } // namespace hlsl #endif //_HLSL_HLSL_DETAILS_H_ diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index ed008eeb04ba8..77a7f773b85b2 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -2246,6 +2246,42 @@ template <typename T, int Sz> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select) vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>); + +/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal, +/// vector<T,Sz> FalseVals) +/// \brief ternary operator for vectors. All vectors must be the same size. +/// \param Conds The Condition input values. +/// \param TrueVal The scalar value to splat from when conditions are true. +/// \param FalseVals The vector values are chosen from when conditions are +/// false. 
+ +template <typename T, int Sz> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select) +vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>); + +/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, +/// T FalseVal) +/// \brief ternary operator for vectors. All vectors must be the same size. +/// \param Conds The Condition input values. +/// \param TrueVals The vector values are chosen from when conditions are true. +/// \param FalseVal The scalar value to splat from when conditions are false. + +template <typename T, int Sz> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select) +vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T); + +/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal, +/// T FalseVal) +/// \brief ternary operator for vectors. All vectors must be the same size. +/// \param Conds The Condition input values. +/// \param TrueVal The scalar value to splat from when conditions are true. +/// \param FalseVal The scalar value to splat from when conditions are false.
+ +template <typename T, int Sz> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select) +__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select( + vector<bool, Sz>, T, T); + //===----------------------------------------------------------------------===// // sin builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index bfe84b16218b7..4ec31cd39eb60 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2213,40 +2213,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) { assert(TheCall->getNumArgs() == 3); Expr *Arg1 = TheCall->getArg(1); + QualType Arg1Ty = Arg1->getType(); Expr *Arg2 = TheCall->getArg(2); - if (!Arg1->getType()->isVectorType()) { - S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type) - << "Second" << TheCall->getDirectCallee() << Arg1->getType() + QualType Arg2Ty = Arg2->getType(); + + QualType Arg1ScalarTy = Arg1Ty; + if (auto VTy = Arg1ScalarTy->getAs<VectorType>()) + Arg1ScalarTy = VTy->getElementType(); + + QualType Arg2ScalarTy = Arg2Ty; + if (auto VTy = Arg2ScalarTy->getAs<VectorType>()) + Arg2ScalarTy = VTy->getElementType(); + + if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy)) + S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch) + << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty; + + QualType Arg0Ty = TheCall->getArg(0)->getType(); + unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements(); + unsigned Arg1Length = Arg1Ty->isVectorType() + ? Arg1Ty->getAs<VectorType>()->getNumElements() + : 0; + unsigned Arg2Length = Arg2Ty->isVectorType() + ? 
Arg2Ty->getAs<VectorType>()->getNumElements() + : 0; + if (Arg1Length > 0 && Arg0Length != Arg1Length) { + S->Diag(TheCall->getBeginLoc(), + diag::err_typecheck_vector_lengths_not_equal) + << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange(); return true; } - if (!Arg2->getType()->isVectorType()) { - S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type) - << "Third" << TheCall->getDirectCallee() << Arg2->getType() - << Arg2->getSourceRange(); - return true; - } - - if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) { + if (Arg2Length > 0 && Arg0Length != Arg2Length) { S->Diag(TheCall->getBeginLoc(), - diag::err_typecheck_call_different_arg_types) - << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange() + diag::err_typecheck_vector_lengths_not_equal) + << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange() << Arg2->getSourceRange(); return true; } - // caller has checked that Arg0 is a vector. - // check all three args have the same length. 
- if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() != - Arg1->getType()->getAs<VectorType>()->getNumElements()) { - S->Diag(TheCall->getBeginLoc(), - diag::err_typecheck_vector_lengths_not_equal) - << TheCall->getArg(0)->getType() << Arg1->getType() - << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange(); - return true; - } - TheCall->setType(Arg1->getType()); + TheCall->setType( + S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length)); return false; } diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl index cade938b71a2b..196b8a90cd877 100644 --- a/clang/test/CodeGenHLSL/builtins/select.hlsl +++ b/clang/test/CodeGenHLSL/builtins/select.hlsl @@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) { int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) { return select(cond0, tVals, fVals); } + +// CHECK-LABEL: test_select_vector_scalar_vector +// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0 +// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}} +// CHECK: ret <4 x i32> [[SELECT]] +int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) { + return select(cond0, tVal, fVals); +} + +// CHECK-LABEL: test_select_vector_vector_scalar +// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0 +// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]] +// CHECK: ret <4 x i32> [[SELECT]] +int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) { + return select(cond0, tVals, fVal); +} + +// CHECK-LABEL: test_select_vector_scalar_scalar +// CHECK: [[SPLAT_SRC1:%.*]] 
= insertelement <4 x i32> poison, i32 {{%.*}}, i64 0 +// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0 +// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer +// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]] +// CHECK: ret <4 x i32> [[SELECT]] +int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) { + return select(cond0, tVal, fVal); +} diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl index 34b5fb6d54cd5..b445cedcba074 100644 --- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl @@ -1,119 +1,65 @@ -// RUN: %clang_cc1 -finclude-default-header -// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -// -disable-llvm-passes -verify -verify-ignore-unexpected +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify -int test_no_arg() { - return select(); - // expected-error@-1 {{no matching function for call to 'select'}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template - // not viable: requires 3 arguments, but 0 were provided}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: requires 3 arguments, but 0 were provided}} -} - -int test_too_few_args(bool p0) { - return select(p0); - // expected-error@-1 {{no matching function for call to 'select'}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: requires 3 arguments, but 1 was provided}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: requires 3 arguments, but 1 was provided}} -} - -int test_too_many_args(bool p0, int t0, 
int f0, int g0) { - return select<int>(p0, t0, f0, g0); - // expected-error@-1 {{no matching function for call to 'select'}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: requires 3 arguments, but 4 were provided}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: requires 3 arguments, but 4 were provided}} -} int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) { return select(p0, t0, f0); - // expected-error@-1 {{no matching function for call to 'select'}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value) - // to 'bool' for 1st argument}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could - // not match 'vector<T, Sz>' against 'int'}} } int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) { return select<int1>(p0, t0, f0); - // expected-warning@-1 {{implicit conversion truncates vector: - // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' - // (vector of 1 'int' value)}} } int2 test_select_vector_vals_not_vecs(bool2 p0, int t0, int f0) { return select(p0, t0, f0); - // expected-error@-1 {{no matching function for call to 'select'}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: - // could not match 'vector<T, Sz>' against 'int'}} - // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not - // viable: no known conversion from 'vector<bool, 2>' - // (vector of 2 'bool' values) to 'bool' for 1st argument}} } int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) { - return select<int,1>(p0, t0, f0); // produce warnings - // expected-warning@-1 {{implicit conversion truncates vector: - // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>' - // (vector of 1 'bool' value)}} - // expected-warning@-2 {{implicit conversion truncates vector: - 
// 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' - // (vector of 1 'int' value)}} + return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}} +} + +int test_select_no_args() { + return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}} +} + +int test_select_builtin_wrong_arg_count(bool p0) { + return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}} } // __builtin_hlsl_select tests -int test_select_builtin_wrong_arg_count(bool p0, int t0) { - return __builtin_hlsl_select(p0, t0); - // expected-error@-1 {{too few arguments to function call, expected 3, - // have 2}} +int test_select_builtin_wrong_arg_count2(bool p0, int t0) { + return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}} +} + +int test_too_many_args(bool p0, int t0, int f0, int g0) { + return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}} } // not a bool or a vector of bool. should be 2 errors. 
int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) { - return __builtin_hlsl_select(p0, t0, f0); - // expected-error@-1 {{passing 'int' to parameter of incompatible type - // 'bool'}} - // expected-error@-2 {{First argument to __builtin_hlsl_select must be of - // vector type}} - } + return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}} +} int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) { - return __builtin_hlsl_select(p0, t0, f0); - // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to - // parameter of incompatible type 'bool'}} - // expected-error@-2 {{First argument to __builtin_hlsl_select must be of - // vector type}} + return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}} } // if a bool last 2 args are of same type int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) { - return __builtin_hlsl_select(p0, t0, f0); - // expected-error@-1 {{arguments are of different types ('int' vs 'double')}} + return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}} } // if a vector second arg isnt a vector int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) { return __builtin_hlsl_select(p0, t0, f0); - // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of - // vector type}} } // if a vector third arg isn't a vector int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) { return __builtin_hlsl_select(p0, t0, f0); - // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of - // vector type}} } // if vector last 2 aren't same type (so both are vectors but wrong type) -int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) { - return 
__builtin_hlsl_select(p0, t0, f0); - // expected-error@-1 {{arguments are of different types ('vector<int, [...]>' - // vs 'vector<float, [...]>')}} +int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) { + return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}} } `````````` </details> https://github.com/llvm/llvm-project/pull/129396 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits