https://github.com/MarwanTarik created 
https://github.com/llvm/llvm-project/pull/171694

This PR is part of https://github.com/llvm/llvm-project/issues/167752. It 
upstreams the codegen and tests for the convert to mask builtins implemented in 
the incubator, including:

Upstream X86 mask conversion builtins from clangir:
- cvtmask2b/w/d/q*
- cvtb/w/d/q2mask* 

Upstreamed helpers:
- emitX86MaskedCompare() and emitX86MaskedCompareResult()
- emitX86ConvertToMask()
- emitX86SExtMask() and getMaskVecValue()

>From 82529b8bfd35c9e8059b49e2f17b3c837232cf09 Mon Sep 17 00:00:00 2001
From: MarwanTarik <[email protected]>
Date: Wed, 10 Dec 2025 22:21:55 +0200
Subject: [PATCH] Upstream CIR Codegen for convert to mask X86 builtins

---
 clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp   | 114 +++++++++++++++++++
 clang/test/CodeGen/X86/avx512vlbw-builtins.c |  12 ++
 2 files changed, 126 insertions(+)

diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index fb17e31bf36d6..bba7249666aaf 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -231,6 +231,113 @@ static mlir::Value emitX86MaskTest(CIRGenBuilderTy &builder, mlir::Location loc,
   return emitIntrinsicCallOp(builder, loc, intrinsicName, resTy,
                              mlir::ValueRange{lhsVec, rhsVec});
 }
+// Pack a vector of per-lane i1 comparison results into an integer mask.
+// Vectors with fewer than 8 lanes are first widened to 8 lanes by shuffling
+// in lanes from an all-zero vector, so the returned mask is at least 8 bits
+// wide. Combining the result with an incoming mask (maskIn) is not
+// implemented yet.
+static mlir::Value emitX86MaskedCompareResult(CIRGenFunction &cgf,
+                                              mlir::Value cmp, unsigned numElts,
+                                              mlir::Value maskIn,
+                                              mlir::Location loc) {
+  if (maskIn)
+    llvm_unreachable("NYI");
+
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  if (numElts < 8) {
+    // Lanes [0, numElts) keep cmp; lanes [numElts, 8) pick elements of the
+    // second shuffle operand, which is all zeros.
+    int64_t lanes[8];
+    for (unsigned lane = 0; lane != 8; ++lane)
+      lanes[lane] = lane < numElts ? lane : numElts + lane % numElts;
+    mlir::Value zeros = builder.getNullValue(cmp.getType(), loc);
+    cmp = builder.createVecShuffle(loc, cmp, zeros, lanes);
+  }
+
+  unsigned maskBits = std::max(numElts, 8U);
+  return builder.createBitcast(cmp, builder.getUIntNTy(maskBits));
+}
+
+// Compare the integer vectors ops[0] and ops[1] lane by lane and return the
+// result packed into an integer mask (see emitX86MaskedCompareResult).
+//   cc       - comparison code: 0=eq, 1=lt, 2=le, 4=ne, 5=ge, 6=gt.
+//              Codes 3 and 7 are not implemented yet.
+//   isSigned - whether the relational predicates (lt/le/ge/gt) compare the
+//              lanes as signed integers.
+//   ops      - {lhs, rhs} or {lhs, rhs, <not read here>, maskIn}.
+static mlir::Value emitX86MaskedCompare(CIRGenFunction &cgf, unsigned cc,
+                                        bool isSigned,
+                                        ArrayRef<mlir::Value> ops,
+                                        mlir::Location loc) {
+  assert((ops.size() == 2 || ops.size() == 4) &&
+         "Unexpected number of arguments");
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  auto vecTy = cast<cir::VectorType>(ops[0].getType());
+  unsigned numElts = vecTy.getSize();
+  mlir::Value cmp;
+
+  if (cc == 3) {
+    llvm_unreachable("NYI");
+  } else if (cc == 7) {
+    llvm_unreachable("NYI");
+  } else {
+    cir::CmpOpKind pred;
+    switch (cc) {
+    default:
+      llvm_unreachable("Unknown condition code");
+    case 0:
+      pred = cir::CmpOpKind::eq;
+      break;
+    case 1:
+      pred = cir::CmpOpKind::lt;
+      break;
+    case 2:
+      pred = cir::CmpOpKind::le;
+      break;
+    case 4:
+      pred = cir::CmpOpKind::ne;
+      break;
+    case 5:
+      pred = cir::CmpOpKind::ge;
+      break;
+    case 6:
+      pred = cir::CmpOpKind::gt;
+      break;
+    }
+
+    mlir::Value lhs = ops[0];
+    mlir::Value rhs = ops[1];
+
+    // CIR integer comparisons take their signedness from the operand element
+    // type, not from the predicate. To honor isSigned for the relational
+    // predicates, re-type the operands when their signedness disagrees.
+    // eq/ne are signedness-agnostic and are left untouched.
+    bool isRelational =
+        pred != cir::CmpOpKind::eq && pred != cir::CmpOpKind::ne;
+    auto eltTy = dyn_cast<cir::IntType>(vecTy.getElementType());
+    if (isRelational && eltTy && eltTy.isSigned() != isSigned) {
+      unsigned width = eltTy.getWidth();
+      mlir::Type newEltTy =
+          isSigned ? builder.getSIntNTy(width) : builder.getUIntNTy(width);
+      auto newVecTy = cir::VectorType::get(newEltTy, numElts);
+      lhs = builder.createBitcast(lhs, newVecTy);
+      rhs = builder.createBitcast(rhs, newVecTy);
+    }
+
+    auto resultTy =
+        builder.getType<cir::VectorType>(builder.getUIntNTy(1), numElts);
+    cmp = cir::VecCmpOp::create(builder, loc, resultTy, pred, lhs, rhs);
+  }
+
+  mlir::Value maskIn;
+  if (ops.size() == 4)
+    maskIn = ops[3];
+
+  return emitX86MaskedCompareResult(cgf, cmp, numElts, maskIn, loc);
+}
+
+// Build an integer mask from the vector `in` by comparing every lane against
+// zero with condition code 1 ("less than") and isSigned=true. The intent is
+// that bit i of the mask is set when lane i is negative (its sign bit is set).
+static mlir::Value emitX86ConvertToMask(CIRGenFunction &cgf, mlir::Value in,
+                                        mlir::Location loc) {
+  mlir::Value zeroVec = cgf.getBuilder().getNullValue(in.getType(), loc);
+  mlir::Value operands[] = {in, zeroVec};
+  return emitX86MaskedCompare(cgf, /*cc=*/1, /*isSigned=*/true, operands,
+                              loc);
+}
+
+// Reinterpret the integer `mask` as a vector of signed i1 lanes, one lane per
+// mask bit. When fewer than 8 elements are requested, the incoming mask is
+// expected to be an i8, and only its first numElts lanes are kept.
+static mlir::Value getMaskVecValue(CIRGenFunction &cgf, mlir::Value mask,
+                                   unsigned numElts, mlir::Location loc) {
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  unsigned maskBits = cast<cir::IntType>(mask.getType()).getWidth();
+  cir::VectorType bitVecTy =
+      cir::VectorType::get(builder.getSIntNTy(1), maskBits);
+  mlir::Value lanes = builder.createBitcast(mask, bitVecTy);
+
+  if (numElts >= 8)
+    return lanes;
+
+  // Keep only the first numElts lanes.
+  llvm::SmallVector<int64_t, 4> keep;
+  for (unsigned idx = 0; idx < numElts; ++idx)
+    keep.push_back(idx);
+  return builder.createVecShuffle(loc, lanes, lanes, keep);
+}
+
+// Expand the integer mask `op` into a vector of type dstTy. Each mask bit
+// becomes a signed i1 lane, and an integral cast widens it to the
+// destination element type. The tests expect this to lower to `sext`, so
+// set bits become all-ones lanes and clear bits become zero lanes.
+static mlir::Value emitX86SExtMask(CIRGenFunction &cgf, mlir::Value op,
+                                   mlir::Type dstTy, mlir::Location loc) {
+  auto dstVecTy = cast<cir::VectorType>(dstTy);
+  mlir::Value maskLanes = getMaskVecValue(cgf, op, dstVecTy.getSize(), loc);
+  return cgf.getBuilder().createCast(loc, cir::CastKind::integral, maskLanes,
+                                     dstTy);
+}
 
 static mlir::Value emitVecInsert(CIRGenBuilderTy &builder, mlir::Location loc,
                                  mlir::Value vec, mlir::Value value,
@@ -558,6 +665,7 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
   case X86::BI__builtin_ia32_storesh128_mask:
   case X86::BI__builtin_ia32_storess128_mask:
   case X86::BI__builtin_ia32_storesd128_mask:
+
   case X86::BI__builtin_ia32_cvtmask2b128:
   case X86::BI__builtin_ia32_cvtmask2b256:
   case X86::BI__builtin_ia32_cvtmask2b512:
@@ -570,6 +678,8 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
   case X86::BI__builtin_ia32_cvtmask2q128:
   case X86::BI__builtin_ia32_cvtmask2q256:
   case X86::BI__builtin_ia32_cvtmask2q512:
+    return emitX86SExtMask(*this, ops[0], convertType(expr->getType()),
+                           getLoc(expr->getExprLoc()));
   case X86::BI__builtin_ia32_cvtb2mask128:
   case X86::BI__builtin_ia32_cvtb2mask256:
   case X86::BI__builtin_ia32_cvtb2mask512:
@@ -582,18 +692,22 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
   case X86::BI__builtin_ia32_cvtq2mask128:
   case X86::BI__builtin_ia32_cvtq2mask256:
   case X86::BI__builtin_ia32_cvtq2mask512:
+    return emitX86ConvertToMask(*this, ops[0], getLoc(expr->getExprLoc()));
   case X86::BI__builtin_ia32_cvtdq2ps512_mask:
   case X86::BI__builtin_ia32_cvtqq2ps512_mask:
   case X86::BI__builtin_ia32_cvtqq2pd512_mask:
   case X86::BI__builtin_ia32_vcvtw2ph512_mask:
   case X86::BI__builtin_ia32_vcvtdq2ph512_mask:
   case X86::BI__builtin_ia32_vcvtqq2ph512_mask:
+    llvm_unreachable("vcvtw2ph256_round_mask NYI");
   case X86::BI__builtin_ia32_cvtudq2ps512_mask:
   case X86::BI__builtin_ia32_cvtuqq2ps512_mask:
   case X86::BI__builtin_ia32_cvtuqq2pd512_mask:
   case X86::BI__builtin_ia32_vcvtuw2ph512_mask:
   case X86::BI__builtin_ia32_vcvtudq2ph512_mask:
   case X86::BI__builtin_ia32_vcvtuqq2ph512_mask:
+    llvm_unreachable("vcvtuw2ph256_round_mask NYI");
+
   case X86::BI__builtin_ia32_vfmaddsh3_mask:
   case X86::BI__builtin_ia32_vfmaddss3_mask:
   case X86::BI__builtin_ia32_vfmaddsd3_mask:
diff --git a/clang/test/CodeGen/X86/avx512vlbw-builtins.c b/clang/test/CodeGen/X86/avx512vlbw-builtins.c
index f6f27d9c3da3d..a088efa6784db 100644
--- a/clang/test/CodeGen/X86/avx512vlbw-builtins.c
+++ b/clang/test/CodeGen/X86/avx512vlbw-builtins.c
@@ -3226,6 +3226,18 @@ __m256i test_mm256_movm_epi8(__mmask32 __A) {
   return _mm256_movm_epi8(__A); 
 }
 
+__m512i test_mm512_movm_epi8(__mmask64 __A) {
+  // CHECK-LABEL: test_mm512_movm_epi8
+  // CHECK: %{{.*}} = bitcast i64 %{{.*}} to <64 x i1>
+  // CHECK: %{{.*}} = sext <64 x i1> %{{.*}} to <64 x i8>
+
+  // CIR-LABEL: _mm512_movm_epi8
+  // CIR: %{{.*}} = cir.cast bitcast %{{.*}} : !u64i -> !cir.vector<!cir.int<s, 1> x 64>
+  // CIR: %{{.*}} = cir.cast integral %{{.*}} : !cir.vector<!cir.int<s, 1> x 64> -> !cir.vector<{{!s8i|!u8i}} x 64>
+
+  // LLVM-LABEL: @test_mm512_movm_epi8
+  // LLVM:  %{{.*}} = bitcast i64 %{{.*}} to <64 x i1>
+  // LLVM:  %{{.*}} = sext <64 x i1> %{{.*}} to <64 x i8>
+  return _mm512_movm_epi8(__A);
+}
+
 __m128i test_mm_movm_epi16(__mmask8 __A) {
   // CHECK-LABEL: test_mm_movm_epi16
   // CHECK: %{{.*}} = bitcast i8 %{{.*}} to <8 x i1>

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to