https://github.com/manbearian updated
https://github.com/llvm/llvm-project/pull/78484
>From b83074fa260d2ce4b876b71d507224cb9476a944 Mon Sep 17 00:00:00 2001
From: Ian Bearman
Date: Wed, 17 Jan 2024 17:54:25 +
Subject: [PATCH] port fixes from local llvm
---
.../IR/BufferizableOpInterface.h | 13 +++
.../BufferizableOpInterfaceImpl.cpp | 13 ---
.../IR/BufferizableOpInterface.cpp| 12 +++---
.../Bufferization/IR/BufferizationOps.cpp | 4 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 4 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 38 ++-
.../BufferizableOpInterfaceImpl.cpp | 8 ++--
mlir/test/Dialect/Linalg/collapse-dim.mlir| 14 +++
8 files changed, 70 insertions(+), 36 deletions(-)
diff --git
a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 63e2d19e68ef97c..478cdab8298754c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -257,6 +257,9 @@ struct BufferizationOptions {
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function;
+ /// Callback that produces the memory space attribute for a given tensor type.
+ using GetMemorySpaceFn =
+ std::function(TensorType t)>;
BufferizationOptions();
@@ -351,6 +354,16 @@ struct BufferizationOptions {
/// used.
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
+ /// Used during type conversion to determine the memory space of a memref
+ /// based on the original tensor type. When unset, defaultMemorySpace is used.
+ GetMemorySpaceFn getMemorySpaceFn = nullptr;
+
+ std::optional getMemorySpace(TensorType t) const {
+if (getMemorySpaceFn)
+ return getMemorySpaceFn(t);
+return defaultMemorySpace;
+ }
+
/// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
/// Should be used only with `testAnalysisOnly = true`.
unsigned analysisFuzzerSeed = 0;
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f69b2557eec922e..337ac0c0761440e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,17 +26,18 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast(op);
+auto type = constantOp.getType().dyn_cast();
+
+// Only ranked tensors are supported.
+if (!type)
+ return failure();
Attribute memorySpace;
-if (options.defaultMemorySpace.has_value())
- memorySpace = *options.defaultMemorySpace;
+if (options.getMemorySpace(type))
+ memorySpace = *options.getMemorySpace(type);
else
return constantOp->emitError("could not infer memory space");
-// Only ranked tensors are supported.
-if (!isa(constantOp.getType()))
- return failure();
-
// Only constants inside a module are supported.
auto moduleOp = constantOp->getParentOfType();
if (!moduleOp)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 4b1dfee4a2b926f..1a849155abed028 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -682,11 +682,11 @@ bufferization::getBufferType(Value value, const
BufferizationOptions &options,
return bufferizableOp.getBufferType(value, options, invocationStack);
// Op is not bufferizable.
- if (!options.defaultMemorySpace.has_value())
+ auto memSpace = options.getMemorySpace(value.getType().cast());
+ if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{},
- *options.defaultMemorySpace);
+ return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -943,11 +943,11 @@ FailureOr
bufferization::detail::defaultGetBufferType(
// If we do not know the memory space and there is no default memory space,
// report a failure.
- if (!options.defaultMemorySpace.has_value())
+ auto memSpace = options.getMemorySpace(value.getType().cast());
+ if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{},
- *options.defaultMemorySpace);
+ return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
b/mlir/lib/Dialect/B