================ @@ -1117,6 +1118,122 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite( return mlir::LogicalResult::success(); } +mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite( + cir::BinOpOverflowOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Location loc = op.getLoc(); + BinOpOverflowKind arithKind = op.getKind(); + IntType operandTy = op.getLhs().getType(); + IntType resultTy = op.getResult().getType(); + + EncompassedTypeInfo encompassedTyInfo = + computeEncompassedTypeWidth(operandTy, resultTy); + mlir::IntegerType encompassedLLVMTy = + rewriter.getIntegerType(encompassedTyInfo.width); + + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); + if (operandTy.getWidth() < encompassedTyInfo.width) { + if (operandTy.isSigned()) { + lhs = rewriter.create<mlir::LLVM::SExtOp>(loc, encompassedLLVMTy, lhs); + rhs = rewriter.create<mlir::LLVM::SExtOp>(loc, encompassedLLVMTy, rhs); + } else { + lhs = rewriter.create<mlir::LLVM::ZExtOp>(loc, encompassedLLVMTy, lhs); + rhs = rewriter.create<mlir::LLVM::ZExtOp>(loc, encompassedLLVMTy, rhs); + } + } + + std::string intrinName = getLLVMIntrinName(arithKind, encompassedTyInfo.sign, + encompassedTyInfo.width); + auto intrinNameAttr = mlir::StringAttr::get(op.getContext(), intrinName); + + mlir::IntegerType overflowLLVMTy = rewriter.getI1Type(); + auto intrinRetTy = mlir::LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), {encompassedLLVMTy, overflowLLVMTy}); + + auto callLLVMIntrinOp = rewriter.create<mlir::LLVM::CallIntrinsicOp>( + loc, intrinRetTy, intrinNameAttr, mlir::ValueRange{lhs, rhs}); + mlir::Value intrinRet = callLLVMIntrinOp.getResult(0); + + mlir::Value result = rewriter + .create<mlir::LLVM::ExtractValueOp>( + loc, intrinRet, ArrayRef<int64_t>{0}) + .getResult(); + mlir::Value overflow = rewriter + .create<mlir::LLVM::ExtractValueOp>( + loc, intrinRet, ArrayRef<int64_t>{1}) + .getResult(); + + if 
(resultTy.getWidth() < encompassedTyInfo.width) { + mlir::Type resultLLVMTy = getTypeConverter()->convertType(resultTy); + auto truncResult = + rewriter.create<mlir::LLVM::TruncOp>(loc, resultLLVMTy, result); + + // Extend the truncated result back to the encompassing type to check for + // any overflows during the truncation. + mlir::Value truncResultExt; + if (resultTy.isSigned()) + truncResultExt = rewriter.create<mlir::LLVM::SExtOp>( + loc, encompassedLLVMTy, truncResult); + else + truncResultExt = rewriter.create<mlir::LLVM::ZExtOp>( + loc, encompassedLLVMTy, truncResult); + auto truncOverflow = rewriter.create<mlir::LLVM::ICmpOp>( + loc, mlir::LLVM::ICmpPredicate::ne, truncResultExt, result); + + result = truncResult; + overflow = rewriter.create<mlir::LLVM::OrOp>(loc, overflow, truncOverflow); + } + + mlir::Type boolLLVMTy = + getTypeConverter()->convertType(op.getOverflow().getType()); + if (boolLLVMTy != rewriter.getI1Type()) + overflow = rewriter.create<mlir::LLVM::ZExtOp>(loc, boolLLVMTy, overflow); + + rewriter.replaceOp(op, mlir::ValueRange{result, overflow}); + + return mlir::success(); +} + +std::string CIRToLLVMBinOpOverflowOpLowering::getLLVMIntrinName( + cir::BinOpOverflowKind opKind, bool isSigned, unsigned width) { + // The intrinsic name is `@llvm.{s|u}{opKind}.with.overflow.i{width}` + + std::string name = "llvm."; ---------------- mmha wrote:
I wonder if this should build the intrinsic name with `llvm::raw_string_ostream` or `llvm::Twine` instead. https://github.com/llvm/llvm-project/pull/133118 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits