llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Fabian Mora (fabianmcg)

<details>
<summary>Changes</summary>

This patch adds the `gather`, `masked_load`, `masked_store`, and `scatter` 
operations to the `ptr` dialect. It also implements translation from these 
operations to LLVM intrinsics:
- ptr.gather -> llvm.masked.gather
- ptr.masked_load -> llvm.masked.load
- ptr.masked_store -> llvm.masked.store
- ptr.scatter -> llvm.masked.scatter

Example:
```mlir
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
                                          %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
  %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
  ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf64>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
  %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf64>
  ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
  llvm.return
}
```
Translates to:
```llvm
define void @mixed_masked_ops_address_spaces(ptr addrspace(3) %0, <4 x ptr addrspace(3)> %1, <4 x i1> %2, <4 x double> %3, <4 x double> %4) {
  %6 = call <4 x double> @llvm.masked.gather.v4f64.v4p3(<4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2, <4 x double> %4)
  call void @llvm.masked.scatter.v4f64.v4p3(<4 x double> %3, <4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2)
  %7 = call <4 x double> @llvm.masked.load.v4f64.p3(ptr addrspace(3) %0, i32 8, <4 x i1> %2, <4 x double> %4)
  call void @llvm.masked.store.v4f64.p3(<4 x double> %3, ptr addrspace(3) %0, i32 8, <4 x i1> %2)
  ret void
}
```

---

Patch is 35.92 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/156368.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td (+236-2) 
- (modified) mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp (+143-13) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp (+119)
- (modified) mlir/test/Dialect/Ptr/ops.mlir (+70) 
- (modified) mlir/test/Target/LLVMIR/ptr.mlir (+114) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 1c88efced950e..170513d57c7be 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,46 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
+//===----------------------------------------------------------------------===//
+// Common props
+//===----------------------------------------------------------------------===//
+
+def AlignmentProp : OptionalProp<I64Prop>;
+
+//===----------------------------------------------------------------------===//
+// Common types
+//===----------------------------------------------------------------------===//
+
+// A shaped value type with value semantics and rank.
+class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
+  ShapedContainerType<allowedTypes,
+    /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
+    /*descr=*/[{A shaped type with value semantics and rank.}],
+    /*cppType=*/"::mlir::ShapedType">;
+
+// A shaped pointer type with value semantics and rank.
+class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped value type of rank 1 of any element type.
+def Ptr_Any1DType :
+  Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `i1` element type.
+def Ptr_Mask1DType :
+  Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of pointer (`!ptr.ptr`) element type.
+def Ptr_Ptr1DType :
+  Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
+
+// Gets the type ID of a type.  
+class TypeIDType<string name> :
+    StrFunc<"$" # name # ".getType().getTypeID()">;
+
+// Checks that all type IDs match.
+class AllTypeIDsMatch<list<string> names> :
+    AllMatchSameOperatorTrait<names, TypeIDType<"_self">.result, "type IDs">;
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
@@ -56,6 +96,58 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GatherOp : Pointer_Op<"gather", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["result", "ptrs"]>, AllTypeIDsMatch<["result", "ptrs"]>
+  ]> {
+  let summary = "Gather operation";
+  let description = [{
+    The `gather` operation performs conditional loads from multiple memory
+    locations specified by `ptrs` based on a mask `mask`. Elements of the
+    result corresponding to masked-off lanes are taken from the passthrough
+    operand.
+
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    Examples:
+    ```mlir
+    // Gather values from multiple memory locations
+    %result = ptr.gather %ptrs, %mask, %passthrough :
+      vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+
+    // Gather with alignment
+    %result = ptr.gather %ptrs, %mask, %passthrough alignment = 8 :
+      vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_Ptr1DType:$ptrs,
+                       Ptr_Mask1DType:$mask,
+                       Ptr_Any1DType:$passthrough,
+                       AlignmentProp:$alignment);
+  let results = (outs Ptr_Any1DType:$result);
+  let assemblyFormat = [{
+    $ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    attr-dict `:` qualified(type($ptrs)) `->` type($result)
+  }];
+  let builders = [
+  OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+      "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GetMetadataOp
 //===----------------------------------------------------------------------===//
@@ -122,8 +214,6 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-def AlignmentProp : OptionalProp<I64Prop>;
-
 def Ptr_LoadOp : Pointer_Op<"load", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
   ]> {
@@ -184,6 +274,150 @@ def Ptr_LoadOp : Pointer_Op<"load", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>
+  ]> {
+  let summary = "Masked load operation";
+  let description = [{
+    The `masked_load` operation performs a conditional load from memory based
+    on  a mask. Elements of the result corresponding to masked-off lanes are
+    taken from the passthrough operand.
+
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    Examples:
+    ```mlir
+    // Masked load with passthrough on vectors
+    %result = ptr.masked_load %ptr, %mask, %passthrough :
+      !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+
+    // Masked load with passthrough on tensors
+    %result = ptr.masked_load %ptr, %mask, %passthrough :
+      !ptr.ptr<#ptr.generic_space> -> tensor<4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_PtrType:$ptr,
+                       Ptr_Mask1DType:$mask,
+                       Ptr_Any1DType:$passthrough,
+                       AlignmentProp:$alignment);
+  let results = (outs Ptr_Any1DType:$result);
+  let assemblyFormat = [{
+    $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    attr-dict `:` qualified(type($ptr)) `->` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+      "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>
+  ]> {
+  let summary = "Masked store operation";
+  let description = [{
+    The `masked_store` operation performs a conditional store to memory based
+    on  a mask. Only elements corresponding to set bits in the mask are written
+    to memory.
+
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the value being stored.
+
+    Examples:
+    ```mlir
+    // Masked store
+    ptr.masked_store %value, %ptr, %mask :
+      vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+
+    // Masked store with alignment
+    ptr.masked_store %value, %ptr, %mask alignment = 8 :
+      vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+    ```
+  }];
+
+  let arguments = (ins Ptr_Any1DType:$value,
+                       Ptr_PtrType:$ptr,
+                       Ptr_Mask1DType:$mask,
+                       AlignmentProp:$alignment);
+  let assemblyFormat = [{
+    $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)? attr-dict `:`
+    type($value) `,` qualified(type($ptr))
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ScatterOp : Pointer_Op<"scatter", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["value", "ptrs"]>, AllTypeIDsMatch<["value", "ptrs"]>
+  ]> {
+  let summary = "Scatter operation";
+  let description = [{
+    The `scatter` operation performs a conditional store of a value `value` to
+    multiple memory locations specified by `ptrs` based on a mask `mask`.
+
+    Only elements corresponding to set bits in the mask are written to memory.
+    The mask operand is a shaped type of `i1` elements that must have the same
+    shape as the value being stored.
+
+    Examples:
+    ```mlir
+    // Scatter values to multiple memory locations
+    ptr.scatter %value, %ptrs, %mask :
+      vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+    
+    // Scatter with alignment
+    ptr.scatter %value, %ptrs, %mask alignment = 8 :
+      vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+    ```
+  }];
+  let arguments = (ins Ptr_Any1DType:$value,
+                       Ptr_Ptr1DType:$ptrs,
+                       Ptr_Mask1DType:$mask,
+                       AlignmentProp:$alignment);
+  let assemblyFormat = [{
+    $value `,` $ptrs `,` $mask  (`alignment` `=` $alignment^)?
+    attr-dict `:` type($value) `,` qualified(type($ptrs))
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment)>
+  ];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 6987926db7e5c..81ae4efd8ec87 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,6 +39,23 @@ void PtrDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// Common helper functions.
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the alignment attribute is a power of 2 if present.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+                function_ref<InFlightDiagnostic()> emitError) {
+  if (!alignment)
+    return success();
+  if (alignment.value() <= 0)
+    return emitError() << "alignment must be positive";
+  if (!llvm::isPowerOf2_64(alignment.value()))
+    return emitError() << "alignment must be a power of 2";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
@@ -84,6 +101,39 @@ LogicalResult FromPtrOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+void GatherOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Gather performs reads from multiple memory locations specified by ptrs
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
+}
+
+LogicalResult GatherOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows loads.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void GatherOp::build(OpBuilder &builder, OperationState &state, Type resultType,
+                     Value ptrs, Value mask, Value passthrough,
+                     unsigned alignment) {
+  build(builder, state, resultType, ptrs, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -107,19 +157,6 @@ verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
   return success();
 }
 
-/// Verifies that the alignment attribute is a power of 2 if present.
-static LogicalResult
-verifyAlignment(std::optional<int64_t> alignment,
-                function_ref<InFlightDiagnostic()> emitError) {
-  if (!alignment)
-    return success();
-  if (alignment.value() <= 0)
-    return emitError() << "alignment must be positive";
-  if (!llvm::isPowerOf2_64(alignment.value()))
-    return emitError() << "alignment must be a power of 2";
-  return success();
-}
-
 void LoadOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -158,6 +195,99 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
         isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
         syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
 }
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+void MaskedLoadOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // MaskedLoad performs reads from the memory location specified by ptr.
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedLoadOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+  // Verify that the pointer type's memory space allows loads.
+  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
+                         Type resultType, Value ptr, Value mask,
+                         Value passthrough, unsigned alignment) {
+  build(builder, state, resultType, ptr, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+void MaskedStoreOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // MaskedStore performs writes to the memory location specified by ptr
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedStoreOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+  // Verify that the pointer type's memory space allows stores.
+  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
+                          Value value, Value ptr, Value mask,
+                          unsigned alignment) {
+  build(builder, state, value, ptr, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+void ScatterOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Scatter performs writes to multiple memory locations specified by ptrs
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
+}
+
+LogicalResult ScatterOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows stores.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value ptrs, Value mask, unsigned alignment) {
+  build(builder, state, value, ptrs, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
 
 //===----------------------------------------------------------------------===//
 // StoreOp
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 906e19901617b..ede3d0de90996 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -207,6 +207,112 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Convert ptr.gather operation
+static LogicalResult
+convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
+  llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
+  llvm::Value *passthrough =
+      moduleTranslation.lookupValue(gatherOp.getPassthrough());
+
+  if (!ptrs || !mask || !passthrough)
+    return gatherOp.emitError("Failed to lookup operands");
+
+  // Convert result type to LLVM type.
+  llvm::Type *resultType =
+      moduleTranslation.convertType(gatherOp.getResult().getType());
+  if (!resultType)
+    return gatherOp.emitError("Failed to convert result type");
+
+  // Get the alignment.
+  llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
+
+  // Create the masked gather intrinsic call.
+  llvm::Value *result = builder.CreateMaskedGather(
+      resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
+
+  moduleTranslation.mapValue(gatherOp.getResult(), result);
+  return success();
+}
+
+/// Convert ptr.masked_load operation
+static LogicalResult
+convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
+  llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
+  llvm::Value *passthrough =
+      moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
+
+  if (!ptr || !mask || !passthrough)
+    return maskedLoadOp.emitError("Failed to lookup operands");
+
+  // Conver...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/156368
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to