https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/80870
>From ed244c7cbd95294077fa603e022ac234aaf19aa2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll...@gmail.com>
Date: Sat, 30 Dec 2023 13:51:52 -0600
Subject: [PATCH 1/5] Implement GroupedConvolutionOpInterface

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  36 +++++
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  51 +++++++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 127 ++++++++++++++++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 102 ++++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 105 +++++++++++++++
 mlir/test/Dialect/Linalg/loops.mlir           |   8 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  11 ++
 mlir/test/Dialect/Linalg/tile-conv.mlir       |   2 +-
 8 files changed, 436 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..e2f24432c003b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,8 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class ConvolutionOpInterface;
+class GroupedConvolutionOpInterface;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +117,37 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+// Common implementations for ConvolutionOpInterface
+namespace convolution_impl {
+// Returns strides as a vector.
+SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
+// Returns dilations as a vector.
+SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
+// Region builder for basic convolution
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                   ArrayRef<NamedAttribute> attrs);
+// Region builder for basic quantized convolution
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                            ArrayRef<NamedAttribute> attrs);
+void getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+                  bool isQuantized = false);
+void print(LinalgOp op, OpAsmPrinter &p);
+} // namespace convolution_impl
+
+// Common implementations for GroupedConvolutionOpInterface
+namespace grouped_convolution_impl {
+int64_t getSpatialRank(GroupedConvolutionOpInterface op);
+ArrayAttr createCommonIndexingMaps(MLIRContext *ctx, int64_t numSpatial,
+                                   int64_t channelPos,
+                                   const SmallVectorImpl<int64_t> &strides,
+                                   const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
+} // namespace grouped_convolution_impl
+
 /// Returns true if the block contains a contraction of the following form:
 ///
 ///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -171,6 +204,9 @@ LogicalResult verifyContractionInterface(Operation *op);
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
+LogicalResult verifyGroupedConvolutionInterface(Operation *op);
+
 /// Verify that `op` conforms to the FillOpInterface.
 LogicalResult verifyFillInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..170ebf9d43030f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,57 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
         return $_op.getOperation()->getOperand(1);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the spatial rank.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getSpatialRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Most convolution's inputs have batch, channel and spatial dims
+        return cast<ShapedType>(image().getType()).getRank() - 2;
+      }]
+    >
   ];
 }
 
+def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
+  LinalgConvolutionOpInterface]> {
+  let description = [{
+    A grouped convolution is defined in general terms:
+    1. It is a convolution as defined by `ConvolutionOpInterface`.
+    2. Operands have the following distinct dimensions (excluding batch in input/output): group, channel, spatial
+    3. `input_rank == kernel_rank == output_rank` (including batch in input/output)
+    4. Reductions are along the input channel and spatial dimensions while group, output channel
+       and output spatial dimensions are parallel.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the iterator types.
+    }],
+    "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return detail::grouped_convolution_impl::getIteratorTypes($_op);
+    }]>,
+    InterfaceMethod<[{
+      Returns strides.
+    }],
+    "::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return detail::convolution_impl::getStrides($_op);
+    }]>,
+    InterfaceMethod<[{
+      Returns dilations.
+    }],
+    "::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return detail::convolution_impl::getDilations($_op);
+    }]>
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..44db786a64595e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,133 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp ops.
+//===----------------------------------------------------------------------===//
+
+def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
+    [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
+
+  let summary = [{
+    Performs N-D grouped convolution with switchable channel position, either first or last.
+  }];
+  let description = [{
+    Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
+    will represent all spatial dimensions. Operand layouts are determined by the `channel_first`
+    `bool` attribute. When placing the channel dim first or last, the batch dim is excluded. In
+    any case, the channel and spatial dims are in the same relative order for all operands.
+
+    Domain: N, G, F, S, C, KS
+
+    Layouts:
+      `channel_first == true`:
+        Input: `NGCS`
+        Kernel: `GFCS`
+        Output: `NGFS`
+
+      `channel_first == false`:
+        Input: `NSGC`
+        Kernel: `SGFC`
+        Output: `NSGF`
+
+  }];
+
+  let arguments = (ins
+    Variadic<TensorOrMemref>:$inputs,
+    Variadic<TensorOrMemref>:$inits,
+    DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+    OptionalAttr<I64ElementsAttr>:$strides,
+    OptionalAttr<I64ElementsAttr>:$dilations
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+    (ins "Value":$input, "Value":$filter, "Value":$init, CArg<"bool", "true">:$channel_first,
+     CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+     CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+      int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
+      if (strides.empty())
+        strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+      if (dilations.empty())
+        dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+      $_state.addAttribute(getStridesAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+              ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+      $_state.addAttribute(getDilationsAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+              ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+      buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
+        attributes, GroupedConvNDOp::getRegionBuilder());
+    }]>,
+    OpBuilder<
+    (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+     "Value":$init, CArg<"bool", "true">:$channel_first,
+     CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+     CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+      int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
+      if (strides.empty())
+        strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+      if (dilations.empty())
+        dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
+      $_state.addAttribute(getStridesAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+              ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
+      $_state.addAttribute(getDilationsAttrName($_state.name),
+          ::mlir::DenseElementsAttr::get(
+              ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
+      buildStructuredOp($_builder, $_state, resultTensorTypes,
+        {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
+    }]>,
+    OpBuilder<
+    (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+     "Value":$init, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+     CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      $_state.addAttribute(getChannelFirstAttrName($_state.name), channel_first);
+      $_state.addAttribute(getStridesAttrName($_state.name), strides);
+      $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
+      buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
+        attributes, GroupedConvNDOp::getRegionBuilder());
+    }]>
+  ];
+
+  // TODO: Figure out how to move this to the interface
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    void print(::mlir::OpAsmPrinter &printer) {
+      return detail::convolution_impl::print(*this, printer);
+    }
+    static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                     ::mlir::OperationState &result) {
+      return detail::convolution_impl::parse(parser, result);
+    }
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return detail::convolution_impl::regionBuilder;
+    }
+    // Implement functions necessary for DestinationStyleOpInterface.
+    MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+
+    // Implement functions necessary for LinalgOp.
+    ArrayAttr getIndexingMaps();
+
+    // Implement functions necessary for GroupedConvolutionOpInterface
+    int64_t getSpatialRank() {
+      return detail::grouped_convolution_impl::getSpatialRank(*this);
+    }
+
+    int64_t getChannelPosition() {
+      return (getChannelFirstAttr().getValue()) ? 1 : getSpatialRank() + 1;
+    }
+  }];
+}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..ba28c4ed954970 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,6 +638,108 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
+SmallVector<int64_t, 2>
+mlir::linalg::detail::convolution_impl::getStrides(ConvolutionOpInterface op) {
+  auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
+  if (!maybeStridesAttr) {
+    OpBuilder builder(op.getContext());
+    return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+  }
+  return llvm::to_vector(maybeStridesAttr.getValues<int64_t>());
+}
+
+SmallVector<int64_t, 2> mlir::linalg::detail::convolution_impl::getDilations(
+    ConvolutionOpInterface op) {
+  auto maybeDilationsAttr =
+      op->getAttrOfType<DenseIntElementsAttr>("dilations");
+  if (!maybeDilationsAttr) {
+    OpBuilder builder(op.getContext());
+    return SmallVector<int64_t, 2>(op.getSpatialRank(), 1);
+  }
+  return llvm::to_vector(maybeDilationsAttr.getValues<int64_t>());
+}
+
+int64_t mlir::linalg::detail::grouped_convolution_impl::getSpatialRank(
+    GroupedConvolutionOpInterface op) {
+  return cast<ShapedType>(op.image().getType()).getRank() - 3;
+}
+
+ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
+    GroupedConvolutionOpInterface op) {
+  int64_t numSpatialDims = op.getSpatialRank();
+  SmallVector<Attribute> iteratorTypes(
+      3 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+  SmallVector<Attribute> reductions(
+      numSpatialDims + 1, IteratorTypeAttr::get(op.getContext(), red));
+  iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+                       reductions.end());
+
+  return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
+ArrayAttr
+mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial, int64_t channelPos,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations) {
+
+  // Domain: (n, g, f, os, c, ks)
+  AffineExpr n = getAffineDimExpr(0, ctx);
+  AffineExpr g = getAffineDimExpr(1, ctx);
+  AffineExpr f = getAffineDimExpr(2, ctx);
+  SmallVector<AffineExpr> s(
+      llvm::map_range(llvm::seq<int64_t>(3, numSpatial + 3),
+                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  AffineExpr c = getAffineDimExpr(numSpatial + 3, ctx);
+  SmallVector<AffineExpr> ks(llvm::map_range(
+      llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
+      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+  // Initialize operand accesses in (n, spatial) order and insert g/c
+  // according to channel position
+  SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+  SmallVector<AffineExpr> gc = {g, c};
+  SmallVector<AffineExpr> gf = {g, f};
+  SmallVector<AffineExpr> gfc = {g, f, c};
+  for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
+    inExprs.push_back(sp * st + ksp * di);
+    outExprs.push_back(sp);
+  }
+  SmallVector<AffineExpr> kExprs(ks);
+  inExprs.insert(inExprs.begin() + channelPos, gc.begin(), gc.end());
+  kExprs.insert(channelPos == 0 ? kExprs.begin()
+                                : kExprs.begin() + channelPos - 1,
+                gfc.begin(), gfc.end());
+  outExprs.insert(outExprs.begin() + channelPos, gf.begin(), gf.end());
+  SmallVector<AffineMap> maps(
+      {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+
+  return Builder(ctx).getAffineMapArrayAttr(maps);
+}
+
+LogicalResult
+mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
+  if (failed(verifyConvolutionInterface(op)))
+    return failure();
+  if (GroupedConvolutionOpInterface conv =
+          dyn_cast<GroupedConvolutionOpInterface>(op)) {
+    const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
+    const auto imageRank = imageType.getRank();
+    const auto kernelRank =
+        conv.filter().getType().cast<ShapedType>().getRank();
+    const auto initType =
+        cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
+    const auto initRank = initType.getRank();
+    if (imageRank != kernelRank || imageRank != initRank)
+      return op->emitError(
+          "Rank relationship must be `in_rank == out_rank == kernel_rank`");
+    return success();
+  }
+  return failure();
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b68aa77fd83a1c..c31b17082b8900 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1663,6 +1663,111 @@ LogicalResult ReduceOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface
+//===----------------------------------------------------------------------===//
+
+// There must be a way to avoid defining the following 3 functions
+ParseResult mlir::linalg::detail::convolution_impl::parse(
+    OpAsmParser &parser, OperationState &result, bool isQuantized) {
+  if (isQuantized)
+    return parseNamedStructuredOp(
+        parser, result, 5,
+        mlir::linalg::detail::convolution_impl::quantizedRegionBuilder);
+  return parseNamedStructuredOp(
+      parser, result, 3, mlir::linalg::detail::convolution_impl::regionBuilder);
+}
+
+void mlir::linalg::detail::convolution_impl::print(LinalgOp op,
+                                                   OpAsmPrinter &p) {
+  printNamedStructuredOp(p, op.getOperation(), op.getDpsInputs(),
+                         op.getDpsInits());
+}
+
+// Build {mul, add} region for convolution
+void mlir::linalg::detail::convolution_impl::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ConvolutionInterface regionBuilder expects 3 args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  SmallVector<Value> yields;
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+void mlir::linalg::detail::convolution_impl::quantizedRegionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 5 &&
+         "ConvolutionInterface regionBuilder expects 5 args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(2));
+  Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+  Value value4 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(1));
+  Value value5 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(3));
+  Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+  Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+  Value value8 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+  helper.yieldOutputs({value8});
+}
+
+void mlir::linalg::detail::convolution_impl::getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!isa<ConvolutionOpInterface>(op))
+    return;
+  if (LinalgOp linalgOp = dyn_cast<LinalgOp>(op)) {
+    if (linalgOp.hasTensorSemantics())
+      return;
+    getGenericEffectsImpl(effects, linalgOp.getOperation()->getResults(),
+                          linalgOp.getDpsInputs(), linalgOp.getDpsInits());
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp
+//===----------------------------------------------------------------------===//
+
+void GroupedConvNDOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  return detail::convolution_impl::getEffects(*this, effects);
+}
+
+ArrayAttr GroupedConvNDOp::getIndexingMaps() {
+  ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+      getContext(), getSpatialRank(), getChannelPosition(), getStridesVector(),
+      getDilationsVector());
+
+  (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 8c13422fd63833..640680483130d7 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,10 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 
 func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..bd728edd1ec715 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// -----
+
+// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
+func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+  // CHECK: grouped_conv_nd {{.*}}channel_first = true
+  linalg.grouped_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
 func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
   %zero = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..1662f5c45fe804 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
 // CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
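
For orientation, here is a tensor-form usage sketch mirroring the memref test added to named-ops.mlir above (hypothetical IR based on the builders and parser wired up in this patch; when the strides/dilations attributes are omitted they default to 1, so the output width is 10 - 3 + 1 = 8):

  func.func @grouped_conv_1d(%in: tensor<64x8x16x10xf32>,
                             %ker: tensor<8x32x16x3xf32>,
                             %out: tensor<64x8x32x8xf32>) -> tensor<64x8x32x8xf32> {
    // N=64, G=8, C=16, W=10; F=32, KW=3. Channel-first layouts: NGCS / GFCS / NGFS.
    %0 = linalg.grouped_conv_nd {channel_first = true}
           ins(%in, %ker : tensor<64x8x16x10xf32>, tensor<8x32x16x3xf32>)
           outs(%out : tensor<64x8x32x8xf32>) -> tensor<64x8x32x8xf32>
    return %0 : tensor<64x8x32x8xf32>
  }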
>From dbd721546e2f2811f4ed6a118412c4cd198530d6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll...@gmail.com>
Date: Sun, 31 Dec 2023 09:58:37 -0600
Subject: [PATCH 2/5] Add bufferization test

---
 mlir/test/Dialect/Linalg/bufferize.mlir | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 29f27e6838e661..9d3444fe2ce9cb 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -217,3 +217,25 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
   %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
   return %3 : tensor<6xi64>
 }
+
+
+
+// -----
+
+// CHECK-LABEL: func @gen_grouped_3D_channel_first_tensor(
+// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<64x2x16x26x26x26xf32>,
+// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<2x20x16x3x3x3xf32>,
+// CHECK-SAME: %[[ARG2_TENSOR:.*]]: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x2x16x26x26x26xf32>
+// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<2x20x16x3x3x3xf32>
+// CHECK-DAG: %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x2x20x8x8x8xf32>
+// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
+// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
+// CHECK: linalg.grouped_conv_nd
+// CHECK-SAME: {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
+// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
+func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+  %0 = linalg.grouped_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+  return %0 : tensor<64x2x20x8x8x8xf32>
+}
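
As a sanity check on the shapes in this test, assuming the usual unpadded convolution size rule per spatial dimension:

    out = (in - dilation * (kernel - 1) - 1) / stride + 1
        = (26 - 2 * (3 - 1) - 1) / 3 + 1
        = 8

which matches the tensor<64x2x20x8x8x8xf32> result type.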
>From 162904694b84924261c6ddd1f9f4430df66f6a10 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll...@gmail.com>
Date: Sun, 31 Dec 2023 10:46:45 -0600
Subject: [PATCH 3/5] Add tiling regression test

---
 mlir/test/Dialect/Linalg/tile-conv.mlir | 60 +++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 1662f5c45fe804..50065ccb18cb84 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -41,3 +41,63 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:   linalg.conv_2d
 //  CHECK-SAME:     ins(%[[SVIN]], %[[SVKER]]
 //  CHECK-SAME:     outs(%[[SVOUT]]
+
+// -----
+
+// CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG:  #[[MAP3:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG:  #[[MAP4:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 6)>
+// CHECK-DAG:  #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
+
+func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
+  linalg.grouped_conv_nd ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.grouped_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:5 = transform.structured.tile_using_for %0 [2, 3, 4, 5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: func @grouped_conv_2D
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[GROUPS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[IN_CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUT_CHANNELS:.*]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[KW:.*]] = memref.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[KH:.*]] = memref.dim %[[ARG1]], %[[C4]]
+// CHECK-DAG: %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+// CHECK-DAG: %[[H:.*]] = memref.dim %[[ARG2]], %[[C4]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+// CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[GROUPS]] step %[[C3]]
+// CHECK: %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[GROUPS]]]
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[OUT_CHANNELS]] step %[[C4]]
+// CHECK-DAG: %[[T6:.*]] = affine.min #[[MAP2]](%[[K]])[%[[OUT_CHANNELS]]]
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[W]] step %[[C5]]
+// CHECK-DAG: %[[T7:.*]] = affine.min #[[MAP3]](%[[L]])[%[[W]]]
+// CHECK: scf.for %[[M:.*]] = %[[C0]] to %[[H]] step %[[C6]]
+// CHECK-DAG: %[[T8:.*]] = affine.min #[[MAP4]](%[[M]])[%[[H]]]
+// CHECK-DAG: %[[T9:.*]] = affine.apply #[[MAP5]](%[[T7]])[%[[KW]]]
+// CHECK-DAG: %[[T10:.*]] = affine.apply #[[MAP5]](%[[T8]])[%[[KH]]]
+// CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
+// CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
+// CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
+// CHECK: linalg.grouped_conv_nd {channel_first = true}
+// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
+// CHECK-SAME: outs(%[[SVOUT]]
\ No newline at end of file
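
Reading createCommonIndexingMaps for the 2-D channel-first case tiled above (unit strides and dilations), the maps the op should report are, as a sketch:

    input:  affine_map<(n, g, f, s0, s1, c, k0, k1) -> (n, g, c, s0 + k0, s1 + k1)>
    kernel: affine_map<(n, g, f, s0, s1, c, k0, k1) -> (g, f, c, k0, k1)>
    output: affine_map<(n, g, f, s0, s1, c, k0, k1) -> (n, g, f, s0, s1)>

The five leading parallel dims (n, g, f, s0, s1) are exactly what the [2, 3, 4, 5, 6] tile sizes target; (c, k0, k1) remain reductions, which is why only five loops are produced.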
>From 1969a22b7b5b50dfeaf27a56237fdbaf99947ae2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll...@gmail.com>
Date: Sun, 31 Dec 2023 11:32:46 -0600
Subject: [PATCH 4/5] Add interface methods for getting channel and group
 sizes

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 29 +++++++++++++++++++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  2 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  6 ++--
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 170ebf9d43030f..3feadfa17a2e5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -202,6 +202,35 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
   let methods = [
+    InterfaceMethod<[{
+      Returns the channel position.
+    }],
+    "int64_t", "getChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Get number of groups.
+    }],
+    "int64_t", "getNumGroups", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition() - 1];
+    }]>,
+    InterfaceMethod<[{
+      Get number of input channels.
+    }],
+    "int64_t", "getNumInputChannels", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Get number of output channels.
+    }],
+    "int64_t", "getNumOutputChannels", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getChannelPosition()];
+    }]>,
     InterfaceMethod<[{
       Returns the iterator types.
     }],
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 44db786a64595e..bc7e7ba004c9b1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -507,7 +507,7 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
     }
 
     int64_t getChannelPosition() {
-      return (getChannelFirstAttr().getValue()) ? 1 : getSpatialRank() + 1;
+      return (getChannelFirstAttr().getValue()) ? 2 : getSpatialRank() + 2;
     }
   }];
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba28c4ed954970..c736bd064bded2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -706,11 +706,11 @@ mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
     outExprs.push_back(sp);
   }
   SmallVector<AffineExpr> kExprs(ks);
-  inExprs.insert(inExprs.begin() + channelPos, gc.begin(), gc.end());
+  inExprs.insert(inExprs.begin() + channelPos - 1, gc.begin(), gc.end());
   kExprs.insert(channelPos == 0 ? kExprs.begin()
-                                : kExprs.begin() + channelPos - 1,
+                                : kExprs.begin() + channelPos - 2,
                 gfc.begin(), gfc.end());
-  outExprs.insert(outExprs.begin() + channelPos, gf.begin(), gf.end());
+  outExprs.insert(outExprs.begin() + channelPos - 1, gf.begin(), gf.end());
   SmallVector<AffineMap> maps(
       {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
        AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
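
To cross-check the corrected dim positions, a hypothetical linalg.generic restatement of the 1-D channel-first case (unit stride and dilation; the mul/add body mirrors convolution_impl::regionBuilder; function and map names are made up):

  #map_in  = affine_map<(n, g, f, s, c, k) -> (n, g, c, s + k)>
  #map_ker = affine_map<(n, g, f, s, c, k) -> (g, f, c, k)>
  #map_out = affine_map<(n, g, f, s, c, k) -> (n, g, f, s)>
  func.func @grouped_conv_1d_generic(%in: tensor<64x8x16x10xf32>,
                                     %ker: tensor<8x32x16x3xf32>,
                                     %out: tensor<64x8x32x8xf32>) -> tensor<64x8x32x8xf32> {
    %0 = linalg.generic {
           indexing_maps = [#map_in, #map_ker, #map_out],
           iterator_types = ["parallel", "parallel", "parallel", "parallel",
                             "reduction", "reduction"]}
           ins(%in, %ker : tensor<64x8x16x10xf32>, tensor<8x32x16x3xf32>)
           outs(%out : tensor<64x8x32x8xf32>) {
    ^bb0(%i: f32, %k: f32, %o: f32):
      // Multiply-accumulate over the reduction dims (c, k).
      %m = arith.mulf %i, %k : f32
      %a = arith.addf %o, %m : f32
      linalg.yield %a : f32
    } -> tensor<64x8x32x8xf32>
    return %0 : tensor<64x8x32x8xf32>
  }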
>From cce8517b88a98137d1e8a4d190c89c697702cd4c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll...@gmail.com>
Date: Mon, 1 Jan 2024 22:49:11 -0600
Subject: [PATCH 5/5] Implement layout attribute to generalize dim positions
 (WIP)

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  9 +--
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 25 ++++++--
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 62 +++++++++++-------
 .../mlir/Dialect/Utils/StructuredOpsUtils.td  | 12 ++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 63 +++++++++++++------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  2 +-
 mlir/test/Dialect/Linalg/bufferize.mlir       |  5 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  4 +-
 mlir/test/Dialect/Linalg/tile-conv.mlir       |  4 +-
 9 files changed, 128 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index e2f24432c003b4..72f65c9e810c67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -141,10 +141,11 @@ void print(LinalgOp op, OpAsmPrinter &p);
 // Common implementations for GroupedConvolutionOpInterface
 namespace grouped_convolution_impl {
 int64_t getSpatialRank(GroupedConvolutionOpInterface op);
-ArrayAttr createCommonIndexingMaps(MLIRContext *ctx, int64_t numSpatial,
-                                   int64_t channelPos,
-                                   const SmallVectorImpl<int64_t> &strides,
-                                   const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations);
 ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
 } // namespace grouped_convolution_impl
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 3feadfa17a2e5f..5ae481a222e3c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -203,9 +203,24 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
   let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
   let methods = [
     InterfaceMethod<[{
-      Returns the channel position.
+      Returns the operand layouts as enums.
     }],
-    "int64_t", "getChannelPosition", (ins)
+    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the groups position for the input.
+    }],
+    "int64_t", "getInputGroupsPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the input.
+    }],
+    "int64_t", "getInputChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the output.
+    }],
+    "int64_t", "getOutputChannelPosition", (ins)
     >,
     InterfaceMethod<[{
       Get number of groups.
@@ -213,7 +228,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
     "int64_t", "getNumGroups", (ins),
     /*methodBody=*/[{}],
     /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition() - 1];
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
     }]>,
     InterfaceMethod<[{
       Get number of input channels.
@@ -221,7 +236,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
    "int64_t", "getNumInputChannels", (ins),
     /*methodBody=*/[{}],
     /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getChannelPosition()];
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
     }]>,
     InterfaceMethod<[{
       Get number of output channels.
@@ -229,7 +244,7 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
     "int64_t", "getNumOutputChannels", (ins),
     /*methodBody=*/[{}],
     /*defaultImplementation=*/[{
-      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getChannelPosition()];
+      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
     }]>,
     InterfaceMethod<[{
       Returns the iterator types.
     }],
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bc7e7ba004c9b1..fcfd9f61aa75e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -396,29 +396,23 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
   }];
   let description = [{
     Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
-    will represent all spatial dimensions. Operand layouts are determined by the `channel_first`
-    `bool` attribute. When placing the channel dim first or last, the batch dim is excluded. In
-    any case, the channel and spatial dims are in the same relative order for all operands.
+    will represent all spatial dimensions. Operand layouts are determined by the `layouts`
+    `StrArrayAttr` attribute. Each element of the array is a string representing the layout of the
+    corresponding operand and should be mappable to a `GroupedConvDim` enum, i.e. one of
+      n: (batch dim)
+      g: (group dim)
+      f: (feature or output channel dim)
+      s: (all spatial dims)
+      c: (input channel dim).
 
-    Domain: N, G, F, S, C, KS
-
-    Layouts:
-      `channel_first == true`:
-        Input: `NGCS`
-        Kernel: `GFCS`
-        Output: `NGFS`
-
-      `channel_first == false`:
-        Input: `NSGC`
-        Kernel: `SGFC`
-        Output: `NSGF`
+    The domain will always be in the order `(N, G, F, S, C, KS)`.
 
   }];
 
   let arguments = (ins
     Variadic<TensorOrMemref>:$inputs,
     Variadic<TensorOrMemref>:$inits,
-    DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+    DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
     OptionalAttr<I64ElementsAttr>:$strides,
     OptionalAttr<I64ElementsAttr>:$dilations
   );
@@ -428,11 +422,10 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<
-    (ins "Value":$input, "Value":$filter, "Value":$init, CArg<"bool", "true">:$channel_first,
+    (ins "Value":$input, "Value":$filter, "Value":$init,
      CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
     [{
-      $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
       int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
       if (strides.empty())
         strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
@@ -449,11 +442,10 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
     }]>,
     OpBuilder<
     (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
-     "Value":$init, CArg<"bool", "true">:$channel_first,
+     "Value":$init,
      CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
     [{
-      $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
       int64_t numSpatialDims = input.getType().cast<ShapedType>().getRank() - 3;
       if (strides.empty())
         strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
@@ -470,10 +462,9 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
     }]>,
     OpBuilder<
     (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
-     "Value":$init, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+     "Value":$init, "Attribute":$strides, "Attribute":$dilations,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
     [{
-      $_state.addAttribute(getChannelFirstAttrName($_state.name), channel_first);
       $_state.addAttribute(getStridesAttrName($_state.name), strides);
       $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
       buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
@@ -506,8 +497,31 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
       return detail::grouped_convolution_impl::getSpatialRank(*this);
     }
 
-    int64_t getChannelPosition() {
-      return (getChannelFirstAttr().getValue()) ? 2 : getSpatialRank() + 2;
+    SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
+      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
+      for (auto attr : (*this).getLayoutsAttr().getValue()) {
+        std::string layoutStr = cast<StringAttr>(attr).getValue().str();
+        SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
+        for (size_t i = 0; i < layoutStr.size(); i++) {
+          auto maybeDimEnum = ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1).c_str());
+          assert(maybeDimEnum);
+          layout[i] = maybeDimEnum.value();
+        }
+        layouts.push_back(layout);
+      }
+      return layouts;
+    }
+
+    int64_t getOutputChannelPosition() {
+      return 2;
+    }
+
+    int64_t getInputChannelPosition() {
+      return 2;
+    }
+
+    int64_t getInputGroupsPosition() {
+      return 1;
     }
   }];
 }
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
index 4200343ce3e132..c7c5d617f6492c 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
   let cppNamespace = "::mlir::utils";
 }
 
+def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
+    [
+      I32EnumAttrCase<"n", 0>, // batch
+      I32EnumAttrCase<"g", 1>, // group
+      I32EnumAttrCase<"f", 2>, // feature (output channel)
+      I32EnumAttrCase<"s", 3>, // spatial
+      I32EnumAttrCase<"c", 4>  // channel (input channel)
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::utils";
+}
+
 #endif // STRUCTURED_OPS_UTILS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c736bd064bded2..b98de5ee259e45 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -679,9 +679,11 @@ ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
 
 ArrayAttr
 mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
-    MLIRContext *ctx, int64_t numSpatial, int64_t channelPos,
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
     const SmallVectorImpl<int64_t> &strides,
     const SmallVectorImpl<int64_t> &dilations) {
+  assert(layouts.size() == 3 && "expected 3 layouts: image, filter, init");
 
   // Domain: (n, g, f, os, c, ks)
   AffineExpr n = getAffineDimExpr(0, ctx);
@@ -695,26 +697,51 @@ mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
       llvm::seq<int64_t>(numSpatial + 4, 2 * (numSpatial + 1) + 2),
       [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
 
-  // Initialize operand accesses in (n, spatial) order and insert g/c
-  // according to channel position
-  SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
-  SmallVector<AffineExpr> gc = {g, c};
-  SmallVector<AffineExpr> gf = {g, f};
-  SmallVector<AffineExpr> gfc = {g, f, c};
+  SmallVector<AffineExpr> inSpatials;
+  inSpatials.reserve(numSpatial);
   for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
-    inExprs.push_back(sp * st + ksp * di);
-    outExprs.push_back(sp);
+    inSpatials.push_back(sp * st + ksp * di);
   }
-  SmallVector<AffineExpr> kExprs(ks);
-  inExprs.insert(inExprs.begin() + channelPos - 1, gc.begin(), gc.end());
-  kExprs.insert(channelPos == 0 ? kExprs.begin()
-                                : kExprs.begin() + channelPos - 2,
-                gfc.begin(), gfc.end());
-  outExprs.insert(outExprs.begin() + channelPos - 1, gf.begin(), gf.end());
+
+  auto getExprs = [&](const SmallVector<utils::GroupedConvDim> &layout,
+                      const SmallVector<AffineExpr> &spatials) {
+    SmallVector<AffineExpr> exprs(layout.size());
+    int64_t spatialDim;
+    for (const auto &[i, dim] : llvm::enumerate(layout)) {
+      switch (dim) {
+      case utils::GroupedConvDim::n:
+        exprs[i] = n;
+        break;
+      case utils::GroupedConvDim::g:
+        exprs[i] = g;
+        break;
+      case utils::GroupedConvDim::f:
+        exprs[i] = f;
+        break;
+      case utils::GroupedConvDim::s:
+        exprs[i] = spatials[0];
+        spatialDim = i;
+        break;
+      case utils::GroupedConvDim::c:
+        exprs[i] = c;
+        break;
+      default:
+        assert(false);
+      }
+    }
+    if (spatials.size() > 1)
+      exprs.insert(exprs.begin() + spatialDim + 1, spatials.begin() + 1,
+                   spatials.end());
+    return exprs;
+  };
+  SmallVector<AffineExpr> inExprs = getExprs(layouts[0], inSpatials);
+  SmallVector<AffineExpr> kExprs = getExprs(layouts[1], ks);
+  SmallVector<AffineExpr> outExprs = getExprs(layouts[2], s);
   SmallVector<AffineMap> maps(
-      {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
-       AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
-       AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+      {AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[0], inSpatials),
+                      ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[1], ks), ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, getExprs(layouts[2], s), ctx)});
 
   return Builder(ctx).getAffineMapArrayAttr(maps);
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c31b17082b8900..29a3f39c8696c9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1761,7 +1761,7 @@ ArrayAttr GroupedConvNDOp::getIndexingMaps() {
     return cached;
 
   cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
-      getContext(), getSpatialRank(), getChannelPosition(), getStridesVector(),
+      getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
       getDilationsVector());
 
   (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 9d3444fe2ce9cb..876fdc9b11dc27 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -232,10 +232,11 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
 // CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
 // CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
 // CHECK: linalg.grouped_conv_nd
-// CHECK-SAME: {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME: dilations = dense<2> : tensor<3xi64>
+// CHECK-SAME: strides = dense<3> : tensor<3xi64>}
 // CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
 // CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
 func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
-  %0 = linalg.grouped_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+  %0 = linalg.grouped_conv_nd {strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
   return %0 : tensor<64x2x20x8x8x8xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index bd728edd1ec715..24177a3a8d7fa6 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -4,8 +4,8 @@
 
 // CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
 func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
-  // CHECK: grouped_conv_nd {{.*}}channel_first = true
-  linalg.grouped_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+  // CHECK: grouped_conv_nd
+  linalg.grouped_conv_nd ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 50065ccb18cb84..475e2565ec5f94 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -52,7 +52,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:  #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
 
 func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
-  linalg.grouped_conv_nd ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+  linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
   return
 }
 
@@ -98,6 +98,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
 // CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
 // CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
-// CHECK: linalg.grouped_conv_nd {channel_first = true}
+// CHECK: linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]}
 // CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
 // CHECK-SAME: outs(%[[SVOUT]]
\ No newline at end of file
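
With the WIP layouts attribute in place, a channel-last variant would presumably be spelled as follows (hypothetical IR: the layout strings are assumptions composed from the GroupedConvDim letters, giving NSGC / SGFC / NSGF operand layouts; the hard-coded position getters above still assume channel-first, so this may not verify yet):

  func.func @grouped_conv_1d_channel_last(%in: tensor<64x10x8x16xf32>,
                                          %ker: tensor<3x8x32x16xf32>,
                                          %out: tensor<64x8x8x32xf32>) -> tensor<64x8x8x32xf32> {
    // N=64, W=10, G=8, C=16; KW=3, F=32; output W = 10 - 3 + 1 = 8.
    %0 = linalg.grouped_conv_nd {layouts = ["nsgc", "sgfc", "nsgf"]}
           ins(%in, %ker : tensor<64x10x8x16xf32>, tensor<3x8x32x16xf32>)
           outs(%out : tensor<64x8x8x32xf32>) -> tensor<64x8x8x32xf32>
    return %0 : tensor<64x8x8x32xf32>
  }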