https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/79494
>From b8fb65dd1e65c36cfb2104e5f35179faa6011552 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaball...@google.com>
Date: Thu, 25 Jan 2024 02:39:14 +0000
Subject: [PATCH] [mlir][Vector] Add patterns for efficient i4 -> i8 conversion
 emulation

This PR adds new patterns to improve the generated vector code for the
emulation of any conversion that has to go through an i4 -> i8 type
extension (only signed extensions are supported for now). This will impact
any i4 -> i8/i16/i32/i64 signed extension as well as sitofp
i4 -> f8/f16/f32/f64.

The asm code generated for the supported cases is significantly better
after this PR for both x86 and aarch64.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 176 ++++++++++++++++--
 .../Vector/vector-rewrite-narrow-types.mlir   |  33 ++++
 2 files changed, 189 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index a4a72754ccc250..8abd34fd246224 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -642,9 +642,9 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that the general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
@@ -652,9 +652,9 @@ struct BitCastRewriter {
 
   /// Rewrite one step of the sequence:
   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in the each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!preconditionType || preconditionType.getRank() != 1 ||
+      preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that the source and destination element types meet the preconditions
+/// for the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  if (!srcType || !dstType)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) conversions are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. The low i4 elements of
+  // each byte are placed in one vector and the high i4 elements in another
+  // vector.
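+  // For example, take the byte 0xCA: its high nibble is 0xC (-4 as signed i4)
+  // and its low nibble is 0xA (-6 as signed i4). With arithmetic right shifts:
+  //   low:  (0xCA << 4) >> 4  =  0xA0 >> 4  =  0xFA  (-6 as i8)
+  //   high:  0xCA >> 4                      =  0xFC  (-4 as i8)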
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
 
@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
 };
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
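+///
+/// The rewrite bitcasts the i4 vector to i8, materializes the sign-extended
+/// low and high nibbles of each byte with pairs of arithmetic shifts, and
+/// re-interleaves them with a single shuffle that restores the original
+/// element order.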
+///
+/// For example:
+///    arith.extsi %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.shli %0, 4 : vector<4xi8>
+///        %2 = arith.shrsi %1, 4 : vector<4xi8>
+///        %3 = arith.shrsi %0, 4 : vector<4xi8>
+///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///                 : vector<4xi8>, vector<4xi8>
+///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///
+///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.shli %0, 4 : vector<4xi8>
+///        %2 = arith.shrsi %1, 4 : vector<4xi8>
+///        %3 = arith.shrsi %0, 4 : vector<4xi8>
+///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///                 : vector<4xi8>, vector<4xi8>
+///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the common conversion preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check the alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
+
+  // Patterns for aligned cases. We set a higher benefit because they are
+  // expected to generate better code for the aligned cases they support.
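+  // (A higher PatternBenefit makes the greedy pattern driver try these
+  // patterns before the generic ones above whenever both match.)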
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a600fa955b1700..c4fbb4c219b917 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK-NOT: arith.extsi
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
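
For readers who want to exercise these patterns outside the transform-dialect
test above, here is a minimal sketch of a driver. It assumes MLIR's standard
greedy rewrite driver; the helper name `emulateNarrowTypes` is hypothetical
and not part of this patch:

  // Sketch only: drive the narrow-type rewrites over a region of IR.
  #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static mlir::LogicalResult emulateNarrowTypes(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    // Registers the generic bitcast rewrites plus the higher-benefit aligned
    // i4 -> i8 patterns added by this patch.
    mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns,
                                                          /*benefit=*/1);
    return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
  }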