================
@@ -5161,6 +5166,157 @@ bool Sema::CheckPPCMMAType(QualType Type, 
SourceLocation TypeLoc) {
   return false;
 }
 
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// When every argument of TheCall has boolean type, implicitly converts each
+// argument to int and stores the converted expression back into the call.
+// Note: UsualArithmeticConversions handles the case where at least
+// one arg isn't a bool, so such calls are left untouched here.
+// Returns true if an implicit conversion failed, false otherwise (including
+// when nothing needed to be converted).
+// Declared static: it is only meant to be used from this file.
+static bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
+  const unsigned NumArgs = TheCall->getNumArgs();
+
+  // Bail out as soon as one argument is not a bool.
+  for (unsigned I = 0; I < NumArgs; ++I)
+    if (!TheCall->getArg(I)->getType()->isBooleanType())
+      return false;
+
+  // If we got here all args are bool: promote each of them to int.
+  for (unsigned I = 0; I < NumArgs; ++I) {
+    ExprResult Converted = S->PerformImplicitConversion(
+        TheCall->getArg(I), S->Context.IntTy, Sema::AA_Converting);
+    if (Converted.isInvalid())
+      return true;
+    TheCall->setArg(I, Converted.get());
+  }
+  return false;
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Handles the CK_HLSLVectorTruncation case for builtins: when the first two
+// arguments are both vectors with different element counts, emits
+// warn_hlsl_impcast_vector_truncation and replaces the larger argument with
+// an implicit cast to the smaller argument's type. Any other combination of
+// argument types leaves the call unchanged.
+// Requires at least two call arguments (asserted).
+// Declared static: it is only meant to be used from this file.
+static void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  Expr *ArgA = TheCall->getArg(0);
+  Expr *ArgB = TheCall->getArg(1);
+
+  const auto *VecTyA = ArgA->getType()->getAs<VectorType>();
+  const auto *VecTyB = ArgB->getType()->getAs<VectorType>();
+  // Truncation only applies when both arguments are vectors...
+  if (!VecTyA || !VecTyB)
+    return;
+  // ...and their element counts actually differ.
+  if (VecTyA->getNumElements() == VecTyB->getNumElements())
+    return;
+
+  const bool AIsLarger = VecTyA->getNumElements() > VecTyB->getNumElements();
+  const unsigned LargerIndex = AIsLarger ? 0 : 1;
+  Expr *LargerArg = AIsLarger ? ArgA : ArgB;
+  Expr *SmallerArg = AIsLarger ? ArgB : ArgA;
+
+  S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
+      << LargerArg->getType() << SmallerArg->getType()
+      << LargerArg->getSourceRange() << SmallerArg->getSourceRange();
+  ExprResult Truncated = S->ImpCastExprToType(
+      LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
+  TheCall->setArg(LargerIndex, Truncated.get());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// If `source` is a vector whose elements are not floating point while
+// `targetTy` is a floating-point type, warns with
+// warn_impcast_integer_float_precision and replaces `source` with a
+// conversion to a float vector that has the same number of elements.
+// Otherwise `source` is left untouched.
+//   source         - expression that must have vector type (asserted);
+//                    updated in place when a conversion is performed.
+//   targetTy       - type that `source` is checked against.
+//   targetSrcRange - additional source range attached to the warning.
+//   BuiltinLoc     - location used for the warning and for the conversion.
+// Declared static: it is only meant to be used from this file.
+static void CheckVectorFloatPromotion(Sema *S, ExprResult &source,
+                                      QualType targetTy,
+                                      SourceRange targetSrcRange,
+                                      SourceLocation BuiltinLoc) {
+  // This is the vector type of `source`, not of the target.
+  const auto *SourceVecTy = source.get()->getType()->getAs<VectorType>();
+  assert(SourceVecTy && "source must have vector type");
+  if (SourceVecTy->getElementType()->isFloatingType() ||
+      !targetTy->isFloatingType())
+    return;
+
+  QualType FloatVecTy = S->Context.getVectorType(
+      S->Context.FloatTy, SourceVecTy->getNumElements(), VectorKind::Generic);
+
+  S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
+      << source.get()->getType() << FloatVecTy
+      << source.get()->getSourceRange() << targetSrcRange;
+  // NOTE(review): the conversion result is stored without an isInvalid()
+  // check; confirm callers cope with an invalid `source`.
+  source = S->SemaConvertVectorExpr(
+      source.get(), S->Context.CreateTypeSourceInfo(FloatVecTy), BuiltinLoc,
+      source.get()->getBeginLoc());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Turns the scalar expression `source` into a vector of type `targetTy` via
+// an implicit CK_VectorSplat cast, updating `source` in place.
+// When the vector's element type is floating point and differs from the
+// scalar's type, the scalar is first cast to the element type; without that
+// step a float vector splat will do an unnecessary cast to double.
+// `targetTy` must be a vector type (asserted).
+// Declared static: it is only meant to be used from this file.
+static void PromoteVectorArgSplat(Sema *S, ExprResult &source,
+                                  QualType targetTy) {
+  const auto *TargetVecTy = targetTy->getAs<VectorType>();
+  assert(TargetVecTy && "splat target must have vector type");
+  QualType ElemTy = TargetVecTy->getElementType();
+  QualType SourceTy = source.get()->getType();
+
+  if (ElemTy->isFloatingType() && SourceTy != ElemTy) {
+    // CK_FloatingCast is only meant for float-to-float conversions; an
+    // integer scalar needs CK_IntegralToFloating to form a well-typed cast.
+    CastKind ScalarCast =
+        SourceTy->isIntegerType() ? CK_IntegralToFloating : CK_FloatingCast;
+    source = S->ImpCastExprToType(source.get(), ElemTy, ScalarCast);
+  }
+  source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Reconciles the types of the first two arguments of TheCall:
+//  - neither is a vector: nothing to do;
+//  - both are vectors with the same element type: the call's type is set to
+//    that element type;
+//  - both are vectors with different element types: diagnoses
+//    err_vec_builtin_incompatible_vector;
+//  - exactly one is a vector: the vector may be converted to a float vector
+//    (see CheckVectorFloatPromotion), the scalar is splatted to the vector's
+//    type, and both arguments are written back to the call.
+// Requires at least two call arguments (asserted).
+// Returns true if an error was diagnosed, false otherwise.
+// Declared static: it is only meant to be used from this file.
+static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  ExprResult A = TheCall->getArg(0);
+  ExprResult B = TheCall->getArg(1);
+  QualType ArgTyA = A.get()->getType();
+  QualType ArgTyB = B.get()->getType();
+  const auto *VecTyA = ArgTyA->getAs<VectorType>();
+  const auto *VecTyB = ArgTyB->getAs<VectorType>();
+
+  if (!VecTyA && !VecTyB)
+    return false;
+
+  if (VecTyA && VecTyB) {
+    if (VecTyA->getElementType() == VecTyB->getElementType()) {
+      TheCall->setType(VecTyA->getElementType());
+      return false;
+    }
+    // Note: type promotion is intended to be handled via the intrinsics
+    //  and not the builtin itself.
+    S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
+        << TheCall->getDirectCallee()
+        << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
+    return true;
+  }
+
+  // Exactly one argument is a vector from here on, so the two cases below
+  // are mutually exclusive.
+  if (VecTyB) {
+    CheckVectorFloatPromotion(S, B, ArgTyA, A.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, A, B.get()->getType());
+  } else {
+    CheckVectorFloatPromotion(S, A, ArgTyB, B.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, B, A.get()->getType());
+  }
+  TheCall->setArg(0, A.get());
+  TheCall->setArg(1, B.get());
+  return false;
+}
+
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) 
{
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2))
+      return true;
+    if (PromoteBoolsToInt(this, TheCall))
+      return true;
+    if (CheckVectorElementCallArgs(this, TheCall))
+      return true;
+    PromoteVectorArgTruncation(this, TheCall);
+    if (SemaBuiltinVectorToScalarMath(TheCall))
----------------
llvm-beanz wrote:

We shouldn’t be testing overload resolution in this change. We should test 
overload resolution separately as we implement HLSL’s overload resolution 
rules. It isn’t possible to get correct overload resolution until we implement 
it, and adding a bunch of complexity to try and test something that isn’t 
implemented seems wrong.

The call to `dot` should not be ambiguous in any cases where exact-match 
overloads exist, so we should be able to test code generation for the intrinsic 
in the exact match cases, and have comprehensive overload resolution tests 
later when that gets implemented.

DXC implemented overload resolution logic multiple times just like this which 
leads to confusing and subtle bugs. It would be best to not make that mistake 
again.

https://github.com/llvm/llvm-project/pull/81190
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to