================
@@ -1117,6 +1118,122 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
   return mlir::LogicalResult::success();
 }
 
+mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite(
+    cir::BinOpOverflowOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const { // Lowers cir::BinOpOverflowOp to an LLVM overflow-intrinsic call.
+  mlir::Location loc = op.getLoc();
+  BinOpOverflowKind arithKind = op.getKind();
+  IntType operandTy = op.getLhs().getType();
+  IntType resultTy = op.getResult().getType();
+  // Width/sign for the arithmetic itself (helper is defined elsewhere).
+  EncompassedTypeInfo encompassedTyInfo =
+      computeEncompassedTypeWidth(operandTy, resultTy);
+  mlir::IntegerType encompassedLLVMTy =
+      rewriter.getIntegerType(encompassedTyInfo.width);
+  // Widen both operands to that width, extending per the operand's sign.
+  mlir::Value lhs = adaptor.getLhs();
+  mlir::Value rhs = adaptor.getRhs();
+  if (operandTy.getWidth() < encompassedTyInfo.width) {
+    if (operandTy.isSigned()) {
+      lhs = rewriter.create<mlir::LLVM::SExtOp>(loc, encompassedLLVMTy, lhs);
+      rhs = rewriter.create<mlir::LLVM::SExtOp>(loc, encompassedLLVMTy, rhs);
+    } else {
+      lhs = rewriter.create<mlir::LLVM::ZExtOp>(loc, encompassedLLVMTy, lhs);
+      rhs = rewriter.create<mlir::LLVM::ZExtOp>(loc, encompassedLLVMTy, rhs);
+    }
+  }
+  // Build the intrinsic name from the op kind and the encompassing sign/width.
+  std::string intrinName = getLLVMIntrinName(arithKind, encompassedTyInfo.sign,
+                                             encompassedTyInfo.width);
+  auto intrinNameAttr = mlir::StringAttr::get(op.getContext(), intrinName);
+  // The call yields a literal struct: {wide integer result, i1 overflow}.
+  mlir::IntegerType overflowLLVMTy = rewriter.getI1Type();
+  auto intrinRetTy = mlir::LLVM::LLVMStructType::getLiteral(
+      rewriter.getContext(), {encompassedLLVMTy, overflowLLVMTy});
+  // Emit the intrinsic call on the (possibly widened) operands.
+  auto callLLVMIntrinOp = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
+      loc, intrinRetTy, intrinNameAttr, mlir::ValueRange{lhs, rhs});
+  mlir::Value intrinRet = callLLVMIntrinOp.getResult(0);
+  // Unpack the struct: field 0 is the value, field 1 is the overflow bit.
+  mlir::Value result = rewriter
+                           .create<mlir::LLVM::ExtractValueOp>(
+                               loc, intrinRet, ArrayRef<int64_t>{0})
+                           .getResult();
+  mlir::Value overflow = rewriter
+                             .create<mlir::LLVM::ExtractValueOp>(
+                                 loc, intrinRet, ArrayRef<int64_t>{1})
+                             .getResult();
+  // The result type is narrower than the arithmetic type: truncate the value.
+  if (resultTy.getWidth() < encompassedTyInfo.width) {
+    mlir::Type resultLLVMTy = getTypeConverter()->convertType(resultTy);
+    auto truncResult =
+        rewriter.create<mlir::LLVM::TruncOp>(loc, resultLLVMTy, result);
+
+    // Extend the truncated result back to the encompassing type to check for
+    // any overflows during the truncation.
+    mlir::Value truncResultExt;
+    if (resultTy.isSigned())
+      truncResultExt = rewriter.create<mlir::LLVM::SExtOp>(
+          loc, encompassedLLVMTy, truncResult);
+    else
+      truncResultExt = rewriter.create<mlir::LLVM::ZExtOp>(
+          loc, encompassedLLVMTy, truncResult);
+    auto truncOverflow = rewriter.create<mlir::LLVM::ICmpOp>(
+        loc, mlir::LLVM::ICmpPredicate::ne, truncResultExt, result);
+    // Report overflow from either the arithmetic or the truncation.
+    result = truncResult;
+    overflow = rewriter.create<mlir::LLVM::OrOp>(loc, overflow, truncOverflow);
+  }
+  // Zero-extend the i1 flag when the converted overflow type is not i1.
+  mlir::Type boolLLVMTy =
+      getTypeConverter()->convertType(op.getOverflow().getType());
+  if (boolLLVMTy != rewriter.getI1Type())
+    overflow = rewriter.create<mlir::LLVM::ZExtOp>(loc, boolLLVMTy, overflow);
+
+  rewriter.replaceOp(op, mlir::ValueRange{result, overflow});
+
+  return mlir::success();
+}
+
+std::string CIRToLLVMBinOpOverflowOpLowering::getLLVMIntrinName(
+    cir::BinOpOverflowKind opKind, bool isSigned, unsigned width) {
+  // The intrinsic name is `@llvm.{s|u}{opKind}.with.overflow.i{width}`
+
+  std::string name = "llvm.";
----------------
mmha wrote:

I wonder if this should use `llvm::raw_string_ostream` or `Twine`.

https://github.com/llvm/llvm-project/pull/133118
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to