https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/156368
This patch adds the `gather`, `masked_load`, `masked_store`, and `scatter` operations to the `ptr` dialect. It also implements translation from these operations to LLVM intrinsics:

- `ptr.gather` -> `llvm.masked.gather`
- `ptr.masked_load` -> `llvm.masked.load`
- `ptr.masked_store` -> `llvm.masked.store`
- `ptr.scatter` -> `llvm.masked.scatter`

Example:
```mlir
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
  %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
  ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf64>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf64>
  ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
  llvm.return
}
```

Translates to:
```llvm
define void @mixed_masked_ops_address_spaces(ptr addrspace(3) %0, <4 x ptr addrspace(3)> %1, <4 x i1> %2, <4 x double> %3, <4 x double> %4) {
  %6 = call <4 x double> @llvm.masked.gather.v4f64.v4p3(<4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2, <4 x double> %4)
  call void @llvm.masked.scatter.v4f64.v4p3(<4 x double> %3, <4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2)
  %7 = call <4 x double> @llvm.masked.load.v4f64.p3(ptr addrspace(3) %0, i32 8, <4 x i1> %2, <4 x double> %4)
  call void @llvm.masked.store.v4f64.p3(<4 x double> %3, ptr addrspace(3) %0, i32 8, <4 x i1> %2)
  ret void
}
```

>From 1898a4301ca6f9ecd2d125217e28cce2abd20e52 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabian...@users.noreply.github.com>
Date: Sun, 31 Aug 2025 12:01:19 +0000
Subject: [PATCH] Add load, store variant ops

---
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    | 238 +++++++++++++++++-
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        | 156 +++++++++++-
 .../Dialect/Ptr/PtrToLLVMIRTranslation.cpp    | 119 +++++++++
 mlir/test/Dialect/Ptr/ops.mlir                |  70 ++++++
 mlir/test/Target/LLVMIR/ptr.mlir              | 114 +++++++++
 5 files changed, 682 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 1c88efced950e..170513d57c7be 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,46 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
+//===----------------------------------------------------------------------===//
+// Common props
+//===----------------------------------------------------------------------===//
+
+def AlignmentProp : OptionalProp<I64Prop>;
+
+//===----------------------------------------------------------------------===//
+// Common types
+//===----------------------------------------------------------------------===//
+
+// A shaped value type with value semantics and rank.
+class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
+  ShapedContainerType<allowedTypes,
+    /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
+    /*descr=*/[{A shaped type with value semantics and rank.}],
+    /*cppType=*/"::mlir::ShapedType">;
+
+// A shaped pointer type with value semantics and rank.
+class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped value type of rank 1 of any element type.
+def Ptr_Any1DType :
+  Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `i1` element type.
+def Ptr_Mask1DType :
+  Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `!ptr.ptr` element type.
+def Ptr_Ptr1DType :
+  Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
+
+// Gets the type ID of a type.
+class TypeIDType<string name> :
+  StrFunc<"$" # name # ".getType().getTypeID()">;
+
+// Checks that all type IDs match.
+class AllTypeIDsMatch<list<string> names> :
+  AllMatchSameOperatorTrait<names, TypeIDType<"_self">.result, "type IDs">;
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
@@ -56,6 +96,58 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GatherOp : Pointer_Op<"gather", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["result", "ptrs"]>, AllTypeIDsMatch<["result", "ptrs"]>
+  ]> {
+  let summary = "Gather operation";
+  let description = [{
+    The `gather` operation performs conditional loads from multiple memory
+    locations specified by `ptrs` based on a mask `mask`. Elements of the
+    result corresponding to masked-off lanes are taken from the passthrough
+    operand.
+
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    Examples:
+    ```mlir
+    // Gather values from multiple memory locations
+    %result = ptr.gather %ptrs, %mask, %passthrough :
+      vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+
+    // Gather with alignment
+    %result = ptr.gather %ptrs, %mask, %passthrough alignment = 8 :
+      vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_Ptr1DType:$ptrs,
+                       Ptr_Mask1DType:$mask,
+                       Ptr_Any1DType:$passthrough,
+                       AlignmentProp:$alignment);
+  let results = (outs Ptr_Any1DType:$result);
+  let assemblyFormat = [{
+    $ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    attr-dict `:` qualified(type($ptrs)) `->` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+      "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GetMetadataOp
 //===----------------------------------------------------------------------===//
@@ -122,8 +214,6 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-def AlignmentProp : OptionalProp<I64Prop>;
-
 def Ptr_LoadOp : Pointer_Op<"load", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
   ]> {
@@ -184,6 +274,150 @@ def Ptr_LoadOp : Pointer_Op<"load", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>
+  ]> {
+  let summary = "Masked load operation";
+  let description = [{
+    The `masked_load` operation performs a conditional load from memory based
+    on a mask. Elements of the result corresponding to masked-off lanes are
+    taken from the passthrough operand.
+
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    Examples:
+    ```mlir
+    // Masked load with passthrough on vectors
+    %result = ptr.masked_load %ptr, %mask, %passthrough :
+      !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+
+    // Masked load with passthrough on tensors
+    %result = ptr.masked_load %ptr, %mask, %passthrough :
+      !ptr.ptr<#ptr.generic_space> -> tensor<4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_PtrType:$ptr,
+                       Ptr_Mask1DType:$mask,
+                       Ptr_Any1DType:$passthrough,
+                       AlignmentProp:$alignment);
+  let results = (outs Ptr_Any1DType:$result);
+  let assemblyFormat = [{
+    $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    attr-dict `:` qualified(type($ptr)) `->` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$ptr, "Value":$mask,
+      "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>
+  ]> {
+  let summary = "Masked store operation";
+  let description = [{
+    The `masked_store` operation performs a conditional store to memory based
+    on a mask. Only elements corresponding to set bits in the mask are written
+    to memory.
+
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the value being stored.
+
+    Examples:
+    ```mlir
+    // Masked store
+    ptr.masked_store %value, %ptr, %mask :
+      vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+
+    // Masked store with alignment
+    ptr.masked_store %value, %ptr, %mask alignment = 8 :
+      vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+    ```
+  }];
+
+  let arguments = (ins Ptr_Any1DType:$value,
+                       Ptr_PtrType:$ptr,
+                       Ptr_Mask1DType:$mask,
+                       AlignmentProp:$alignment);
+  let assemblyFormat = [{
+    $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)? attr-dict `:`
+    type($value) `,` qualified(type($ptr))
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ScatterOp : Pointer_Op<"scatter", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["value", "ptrs"]>, AllTypeIDsMatch<["value", "ptrs"]>
+  ]> {
+  let summary = "Scatter operation";
+  let description = [{
+    The `scatter` operation performs a conditional store of a value `value` to
+    multiple memory locations specified by `ptrs` based on a mask `mask`.
+
+    Only elements corresponding to set bits in the mask are written to memory.
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the value being stored.
+
+    Examples:
+    ```mlir
+    // Scatter values to multiple memory locations
+    ptr.scatter %value, %ptrs, %mask :
+      vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+
+    // Scatter with alignment
+    ptr.scatter %value, %ptrs, %mask alignment = 8 :
+      vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+    ```
+  }];
+  let arguments = (ins Ptr_Any1DType:$value,
+                       Ptr_Ptr1DType:$ptrs,
+                       Ptr_Mask1DType:$mask,
+                       AlignmentProp:$alignment);
+  let assemblyFormat = [{
+    $value `,` $ptrs `,` $mask (`alignment` `=` $alignment^)?
+    attr-dict `:` type($value) `,` qualified(type($ptrs))
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 6987926db7e5c..81ae4efd8ec87 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,6 +39,23 @@ void PtrDialect::initialize() {
   >();
 }
 
+//===----------------------------------------------------------------------===//
+// Common helper functions.
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the alignment attribute is a power of 2 if present.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+                function_ref<InFlightDiagnostic()> emitError) {
+  if (!alignment)
+    return success();
+  if (alignment.value() <= 0)
+    return emitError() << "alignment must be positive";
+  if (!llvm::isPowerOf2_64(alignment.value()))
+    return emitError() << "alignment must be a power of 2";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
@@ -84,6 +101,39 @@ LogicalResult FromPtrOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+void GatherOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Gather performs reads from multiple memory locations specified by ptrs
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
+}
+
+LogicalResult GatherOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows loads.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void GatherOp::build(OpBuilder &builder, OperationState &state,
+                     Type resultType, Value ptrs, Value mask,
+                     Value passthrough, unsigned alignment) {
+  build(builder, state, resultType, ptrs, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -107,19 +157,6 @@ verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
   return success();
 }
 
-/// Verifies that the alignment attribute is a power of 2 if present.
-static LogicalResult
-verifyAlignment(std::optional<int64_t> alignment,
-                function_ref<InFlightDiagnostic()> emitError) {
-  if (!alignment)
-    return success();
-  if (alignment.value() <= 0)
-    return emitError() << "alignment must be positive";
-  if (!llvm::isPowerOf2_64(alignment.value()))
-    return emitError() << "alignment must be a power of 2";
-  return success();
-}
-
 void LoadOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -158,6 +195,99 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
         isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
         syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
 }
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+void MaskedLoadOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // MaskedLoad performs reads from the memory location specified by ptr.
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedLoadOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+  // Verify that the pointer type's memory space allows loads.
+  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
+                         Type resultType, Value ptr, Value mask,
+                         Value passthrough, unsigned alignment) {
+  build(builder, state, resultType, ptr, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+void MaskedStoreOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // MaskedStore performs writes to the memory location specified by ptr
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedStoreOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+  // Verify that the pointer type's memory space allows stores.
+  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
+                          Value value, Value ptr, Value mask,
+                          unsigned alignment) {
+  build(builder, state, value, ptr, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+void ScatterOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Scatter performs writes to multiple memory locations specified by ptrs
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
+}
+
+LogicalResult ScatterOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows stores.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value ptrs, Value mask, unsigned alignment) {
+  build(builder, state, value, ptrs, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
 
 //===----------------------------------------------------------------------===//
 // StoreOp
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 906e19901617b..ede3d0de90996 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -207,6 +207,112 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Convert ptr.gather operation
+static LogicalResult
+convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
+  llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
+  llvm::Value *passthrough =
+      moduleTranslation.lookupValue(gatherOp.getPassthrough());
+
+  if (!ptrs || !mask || !passthrough)
+    return gatherOp.emitError("Failed to lookup operands");
+
+  // Convert result type to LLVM type.
+  llvm::Type *resultType =
+      moduleTranslation.convertType(gatherOp.getResult().getType());
+  if (!resultType)
+    return gatherOp.emitError("Failed to convert result type");
+
+  // Get the alignment.
+  llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
+
+  // Create the masked gather intrinsic call.
+  llvm::Value *result = builder.CreateMaskedGather(
+      resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
+
+  moduleTranslation.mapValue(gatherOp.getResult(), result);
+  return success();
+}
+
+/// Convert ptr.masked_load operation
+static LogicalResult
+convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
+  llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
+  llvm::Value *passthrough =
+      moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
+
+  if (!ptr || !mask || !passthrough)
+    return maskedLoadOp.emitError("Failed to lookup operands");
+
+  // Convert result type to LLVM type.
+  llvm::Type *resultType =
+      moduleTranslation.convertType(maskedLoadOp.getResult().getType());
+  if (!resultType)
+    return maskedLoadOp.emitError("Failed to convert result type");
+
+  // Get the alignment.
+  llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
+
+  // Create the masked load intrinsic call.
+  llvm::Value *result = builder.CreateMaskedLoad(
+      resultType, ptr, alignment.valueOrOne(), mask, passthrough);
+
+  moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
+  return success();
+}
+
+/// Convert ptr.masked_store operation
+static LogicalResult
+convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
+  llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
+  llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
+
+  if (!value || !ptr || !mask)
+    return maskedStoreOp.emitError("Failed to lookup operands");
+
+  // Get the value type.
+  llvm::Type *valueType = value->getType();
+  if (!valueType)
+    return maskedStoreOp.emitError("Failed to get value type");
+
+  // Get the alignment.
+  llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
+
+  // Create the masked store intrinsic call.
+  builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
+  return success();
+}
+
+/// Convert ptr.scatter operation
+static LogicalResult
+convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
+  llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
+  llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
+
+  if (!value || !ptrs || !mask)
+    return scatterOp.emitError("Failed to lookup operands");
+
+  // Get the value type
+  llvm::Type *valueType = value->getType();
+  if (!valueType)
+    return scatterOp.emitError("Failed to get value type");
+
+  // Get the alignment.
+  llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
+
+  // Create the masked scatter intrinsic call.
+  builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
+  return success();
+}
+
 /// Implementation of the dialect interface that converts operations belonging
 /// to the `ptr` dialect to LLVM IR.
 class PtrDialectLLVMIRTranslationInterface
@@ -233,6 +339,19 @@ class PtrDialectLLVMIRTranslationInterface
           .Case<TypeOffsetOp>([&](TypeOffsetOp typeOffsetOp) {
             return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
           })
+          .Case<GatherOp>([&](GatherOp gatherOp) {
+            return convertGatherOp(gatherOp, builder, moduleTranslation);
+          })
+          .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
+            return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
+          })
+          .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
+            return convertMaskedStoreOp(maskedStoreOp, builder,
+                                        moduleTranslation);
+          })
+          .Case<ScatterOp>([&](ScatterOp scatterOp) {
+            return convertScatterOp(scatterOp, builder, moduleTranslation);
+          })
           .Default([&](Operation *op) {
             return op->emitError("Translation for operation '")
                    << op->getName() << "' is not implemented.";
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 3f3ad05c46acc..bde2fb22b6aa0 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -56,3 +56,73 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<2>>, %arg1: f32, %arg2
   ptr.store %arg2, %arg0 atomic release alignment = 8 : i64, !ptr.ptr<#llvm.address_space<2>>
   return
 }
+
+/// Test gather operations
+func.func @gather_ops(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+  %0 = ptr.gather %ptrs, %mask, %passthrough : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+  %1 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+/// Test gather operations with tensors
+func.func @gather_ops_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>, %passthrough: tensor<8xi32>) -> tensor<8xi32> {
+  %0 = ptr.gather %ptrs, %mask, %passthrough : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+  %1 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+
+/// Test scatter operations
+func.func @scatter_ops(%value: vector<4xf32>, %ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4xi1>) {
+  ptr.scatter %value, %ptrs, %mask : vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+  ptr.scatter %value, %ptrs, %mask alignment = 16 : vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+/// Test scatter operations with tensors
+func.func @scatter_ops_tensor(%value: tensor<8xi64>, %ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>) {
+  ptr.scatter %value, %ptrs, %mask : tensor<8xi64>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+  ptr.scatter %value, %ptrs, %mask alignment = 8 : tensor<8xi64>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+/// Test masked load operations
+func.func @masked_load_ops(%ptr: !ptr.ptr<#ptr.generic_space>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+  %0 = ptr.masked_load %ptr, %mask, %passthrough : !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 16 : !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+/// Test masked load operations with tensors
+func.func @masked_load_ops_tensor(%ptr: !ptr.ptr<#ptr.generic_space>, %mask: tensor<8xi1>, %passthrough: tensor<8xi32>) -> tensor<8xi32> {
+  %0 = ptr.masked_load %ptr, %mask, %passthrough : !ptr.ptr<#ptr.generic_space> -> tensor<8xi32>
+  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 4 : !ptr.ptr<#ptr.generic_space> -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+
+/// Test masked store operations
+func.func @masked_store_ops(%value: vector<4xf32>, %ptr: !ptr.ptr<#ptr.generic_space>, %mask: vector<4xi1>) {
+  ptr.masked_store %value, %ptr, %mask : vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+  ptr.masked_store %value, %ptr, %mask alignment = 32 : vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+  return
+}
+
+/// Test masked store operations with tensors
+func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.generic_space>, %mask: tensor<8xi1>) {
+  ptr.masked_store %value, %ptr, %mask : tensor<8xi64>, !ptr.ptr<#ptr.generic_space>
+  ptr.masked_store %value, %ptr, %mask alignment = 8 : tensor<8xi64>, !ptr.ptr<#ptr.generic_space>
+  return
+}
+
+/// Test operations with LLVM address space
+func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
+                           %mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+  // Gather from shared memory (address space 3)
+  %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32>
+  // Scatter to shared memory
+  ptr.scatter %value, %ptrs, %mask alignment = 4 : vector<4xf32>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
+  // Masked load from shared memory
+  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 4 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf32>
+  // Masked store to shared memory
+  ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 6e3b365b862e2..545bec5979b2d 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -89,3 +89,117 @@ llvm.func @store_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>, %arg1: f32, %arg2:
   ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#llvm.address_space<0>>
   llvm.return
 }
+
+// CHECK-LABEL: define <4 x float> @gather_ops
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i1> %[[MASK:.*]], <4 x float> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p0(<4 x ptr> %[[PTRS]], i32 1, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: %[[V1:.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p0(<4 x ptr> %[[PTRS]], i32 4, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <4 x float> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @gather_ops(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+  // Basic gather
+  %0 = ptr.gather %ptrs, %mask, %passthrough : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xf32>
+  // Gather with alignment
+  %1 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xf32>
+  llvm.return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x i32> @gather_ops_i32
+// CHECK-SAME: (<8 x ptr> %[[PTRS:.*]], <8 x i1> %[[MASK:.*]], <8 x i32> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %[[PTRS]], i32 8, <8 x i1> %[[MASK]], <8 x i32> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <8 x i32> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @gather_ops_i32(%ptrs: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<8xi1>, %passthrough: vector<8xi32>) -> vector<8xi32> {
+  %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32>
+  llvm.return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: define <4 x float> @masked_load_ops
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i1> %[[MASK:.*]], <4 x float> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %[[PTR]], i32 1, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: %[[V1:.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %[[PTR]], i32 16, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <4 x float> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @masked_load_ops(%ptr: !ptr.ptr<#llvm.address_space<0>>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+  // Basic masked load
+  %0 = ptr.masked_load %ptr, %mask, %passthrough : !ptr.ptr<#llvm.address_space<0>> -> vector<4xf32>
+  // Masked load with alignment
+  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 16 : !ptr.ptr<#llvm.address_space<0>> -> vector<4xf32>
+  llvm.return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x i64> @masked_load_ops_i64
+// CHECK-SAME: (ptr %[[PTR:.*]], <8 x i1> %[[MASK:.*]], <8 x i64> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <8 x i64> @llvm.masked.load.v8i64.p0(ptr %[[PTR]], i32 8, <8 x i1> %[[MASK]], <8 x i64> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <8 x i64> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @masked_load_ops_i64(%ptr: !ptr.ptr<#llvm.address_space<0>>, %mask: vector<8xi1>, %passthrough: vector<8xi64>) -> vector<8xi64> {
+  %0 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<0>> -> vector<8xi64>
+  llvm.return %0 : vector<8xi64>
+}
+
+// CHECK-LABEL: define void @masked_store_ops
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x float> %[[VALUE:.*]], <4 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> %[[VALUE]], ptr %[[PTR]], i32 1, <4 x i1> %[[MASK]])
+// CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> %[[VALUE]], ptr %[[PTR]], i32 32, <4 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @masked_store_ops(%ptr: !ptr.ptr<#llvm.address_space<0>>, %value: vector<4xf32>, %mask: vector<4xi1>) {
+  // Basic masked store
+  ptr.masked_store %value, %ptr, %mask : vector<4xf32>, !ptr.ptr<#llvm.address_space<0>>
+  // Masked store with alignment
+  ptr.masked_store %value, %ptr, %mask alignment = 32 : vector<4xf32>, !ptr.ptr<#llvm.address_space<0>>
+  llvm.return
+}
+
+// CHECK-LABEL: define void @masked_store_ops_i16
+// CHECK-SAME: (ptr %[[PTR:.*]], <8 x i16> %[[VALUE:.*]], <8 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.store.v8i16.p0(<8 x i16> %[[VALUE]], ptr %[[PTR]], i32 4, <8 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @masked_store_ops_i16(%ptr: !ptr.ptr<#llvm.address_space<0>>, %value: vector<8xi16>, %mask: vector<8xi1>) {
+  ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<8xi16>, !ptr.ptr<#llvm.address_space<0>>
+  llvm.return
+}
+
+// CHECK-LABEL: define void @scatter_ops
+// CHECK-SAME: (<4 x float> %[[VALUE:.*]], <4 x ptr> %[[PTRS:.*]], <4 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.scatter.v4f32.v4p0(<4 x float> %[[VALUE]], <4 x ptr> %[[PTRS]], i32 1, <4 x i1> %[[MASK]])
+// CHECK-NEXT: call void @llvm.masked.scatter.v4f32.v4p0(<4 x float> %[[VALUE]], <4 x ptr> %[[PTRS]], i32 8, <4 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @scatter_ops(%value: vector<4xf32>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<4xi1>) {
+  // Basic scatter
+  ptr.scatter %value, %ptrs, %mask : vector<4xf32>, vector<4x!ptr.ptr<#llvm.address_space<0>>>
+  // Scatter with alignment
+  ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf32>, vector<4x!ptr.ptr<#llvm.address_space<0>>>
+  llvm.return
+}
+
+// CHECK-LABEL: define void @scatter_ops_i64
+// CHECK-SAME: (<8 x i64> %[[VALUE:.*]], <8 x ptr> %[[PTRS:.*]], <8 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.scatter.v8i64.v8p0(<8 x i64> %[[VALUE]], <8 x ptr> %[[PTRS]], i32 16, <8 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<8xi1>) {
+  ptr.scatter %value, %ptrs, %mask alignment = 16 : vector<8xi64>, vector<8x!ptr.ptr<#llvm.address_space<0>>>
+  llvm.return
+}
+
+// CHECK-LABEL: define void @mixed_masked_ops_address_spaces
+// CHECK-SAME: (ptr addrspace(3) %[[PTR_SHARED:.*]], <4 x ptr addrspace(3)> %[[PTRS_SHARED:.*]], <4 x i1> %[[MASK:.*]], <4 x double> %[[VALUE_F64:.*]], <4 x double> %[[PASSTHROUGH_F64:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <4 x double> @llvm.masked.gather.v4f64.v4p3(<4 x ptr addrspace(3)> %[[PTRS_SHARED]], i32 8, <4 x i1> %[[MASK]], <4 x double> %[[PASSTHROUGH_F64]])
+// CHECK-NEXT: call void @llvm.masked.scatter.v4f64.v4p3(<4 x double> %[[VALUE_F64]], <4 x ptr addrspace(3)> %[[PTRS_SHARED]], i32 8, <4 x i1> %[[MASK]])
+// CHECK-NEXT: %[[V1:.*]] = call <4 x double> @llvm.masked.load.v4f64.p3(ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]], <4 x double> %[[PASSTHROUGH_F64]])
+// CHECK-NEXT: call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
+                                           %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
+  // Test with shared memory address space (3) and f64 elements
+  %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
+  ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf64>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
+  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf64>
+  ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
+  llvm.return
+}

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
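Appendix (not part of the patch): for readers who want to try the new operations end to end, below is a minimal sketch of a standalone module that exercises all four ops in the generic address space. It assumes the patch is applied as-is and simply mirrors the syntax used in the `ops.mlir` tests above; the function name `@masked_ops_example` and the chosen alignments are illustrative only.

```mlir
// Round-trip all four masked operations in the generic address space.
func.func @masked_ops_example(%ptr: !ptr.ptr<#ptr.generic_space>,
                              %ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>,
                              %mask: vector<4xi1>,
                              %value: vector<4xf32>,
                              %passthrough: vector<4xf32>) -> vector<4xf32> {
  // Load four lanes through a vector of pointers; masked-off lanes come from %passthrough.
  %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
  // Store the gathered lanes back through the same pointers.
  ptr.scatter %0, %ptrs, %mask alignment = 4 : vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
  // Contiguous masked load and store through a single base pointer.
  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 4 : !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
  ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#ptr.generic_space>
  return %1 : vector<4xf32>
}
```

Swapping `#ptr.generic_space` for `#llvm.address_space<...>` and running the module through `mlir-translate --mlir-to-llvmir` should produce the `llvm.masked.*` intrinsic calls shown at the top of this message, via the translation interface added in `PtrToLLVMIRTranslation.cpp`.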