https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/123177
>From 137705661c184ea1530982c19163341933ab421e Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Wed, 15 Jan 2025 09:09:53 -0800 Subject: [PATCH 1/4] [mlir][LLVM] add argument and result attributes to llvm.call --- llvm/include/llvm/IR/InstrTypes.h | 11 +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 4 +- .../include/mlir/Target/LLVMIR/ModuleImport.h | 8 ++- .../mlir/Target/LLVMIR/ModuleTranslation.h | 9 ++- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 67 +++++++++++++------ .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 21 ++++++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 35 ++++++++++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 52 ++++++++++---- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 + mlir/test/Dialect/LLVMIR/roundtrip.mlir | 20 ++++++ .../LLVMIR/Import/call-argument-attributes.ll | 25 +++++++ .../LLVMIR/call-argument-attributes.mlir | 17 +++++ 12 files changed, 230 insertions(+), 41 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll create mode 100644 mlir/test/Target/LLVMIR/call-argument-attributes.mlir diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index b8d9cc10292f4a..0e391325eebdce 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1490,6 +1490,11 @@ class CallBase : public Instruction { Attrs = Attrs.addRetAttribute(getContext(), Attr); } + /// Adds attributes to the return value. 
+ void addRetAttrs(const AttrBuilder &B) { + Attrs = Attrs.addRetAttributes(getContext(), B); + } + /// Adds the attribute to the indicated argument void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) { assert(ArgNo < arg_size() && "Out of bounds"); @@ -1502,6 +1507,12 @@ class CallBase : public Instruction { Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr); } + /// Adds attributes to the indicated argument + void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) { + assert(ArgNo < arg_size() && "Out of bounds"); + Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B); + } + /// removes the attribute from the list of attributes. void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) { Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b2281536aa40b6..85f5c6cc8cca07 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -755,7 +755,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", VariadicOfVariadic<LLVM_Type, "op_bundle_sizes">:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr<ArrayAttr>:$op_bundle_tags); + OptionalAttr<ArrayAttr>:$op_bundle_tags, + OptionalAttr<DictArrayAttr>:$arg_attrs, + OptionalAttr<DictArrayAttr>:$res_attrs); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. 
let arguments = !con(args, aliasAttrs); let results = (outs Optional<LLVM_Type>:$result); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 33c9af7c6335a4..86e1d6a04cd096 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -326,14 +326,18 @@ class ModuleImport { SmallVectorImpl<Type> &types, SmallVectorImpl<Value> &operands, bool allowInlineAsm = false); - /// Converts the parameter attributes attached to `func` and adds them to the - /// `funcOp`. + /// Converts the parameter and result attributes attached to `func` and adds + /// them to the `funcOp`. void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, OpBuilder &builder); /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding /// DictionaryAttr for the LLVM dialect. DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, OpBuilder &builder); + /// Converts the parameter and result attributes attached to `call` and adds + /// them to the `callOp`. + void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, + OpBuilder &builder); /// Returns the builtin type equivalent to the given LLVM dialect type or /// nullptr if there is no equivalent. The returned type can be used to create /// an attribute for a GlobalOp or a ConstantOp. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 1b62437761ed9d..88fc17ca4fda24 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -228,6 +228,11 @@ class ModuleTranslation { /*recordInsertions=*/false); } + /// Translates parameter attributes of a call and adds them to the returned + /// AttrBuilder. Returns failure if any of the translations failed. 
+ FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx, + DictionaryAttr paramAttrs); + /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name); @@ -346,8 +351,8 @@ class ModuleTranslation { convertDialectAttributes(Operation *op, ArrayRef<llvm::Instruction *> instructions); - /// Translates parameter attributes and adds them to the returned AttrBuilder. - /// Returns failure if any of the translations failed. + /// Translates parameter attributes of a function and adds them to the + /// returned AttrBuilder. Returns failure if any of the translations failed. FailureOr<llvm::AttrBuilder> convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index ef1e0222e05f06..6c4988bac7813e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, 
/*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1331,42 +1335,52 @@ void CallOp::print(OpAsmPrinter &p) { getVarCalleeTypeAttrName(), getCConvAttrName(), getOperandSegmentSizesAttrName(), getOpBundleSizesAttrName(), - getOpBundleTagsAttrName()}); + getOpBundleTagsAttrName(), getArgAttrsAttrName(), + getResAttrsAttrName()}); p << " : "; if (!isDirect) p << getOperand(0).getType() << ", "; // Reconstruct the function MLIR function type from operand and result types. - p.printFunctionalType(args.getTypes(), getResultTypes()); + call_interface_impl::printFunctionSignature( + p, *this, args.getTypes(), /*isVariadic=*/false, getResultTypes()); } /// Parses the type of a call operation and resolves the operands if the parsing /// succeeds. Returns failure otherwise. 
static ParseResult parseCallTypeAndResolveOperands( OpAsmParser &parser, OperationState &result, bool isDirect, - ArrayRef<OpAsmParser::UnresolvedOperand> operands) { + ArrayRef<OpAsmParser::UnresolvedOperand> operands, + SmallVectorImpl<DictionaryAttr> &argAttrs, + SmallVectorImpl<DictionaryAttr> &resultAttrs) { SMLoc trailingTypesLoc = parser.getCurrentLocation(); SmallVector<Type> types; - if (parser.parseColonTypeList(types)) + if (parser.parseColon()) return failure(); - - if (isDirect && types.size() != 1) - return parser.emitError(trailingTypesLoc, - "expected direct call to have 1 trailing type"); - if (!isDirect && types.size() != 2) - return parser.emitError(trailingTypesLoc, - "expected indirect call to have 2 trailing types"); - - auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val()); - if (!funcType) + if (!isDirect) { + types.emplace_back(); + if (parser.parseType(types.back())) + return failure(); + if (parser.parseOptionalComma()) + return parser.emitError( + trailingTypesLoc, "expected indirect call to have 2 trailing types"); + } + SmallVector<Type> argTypes; + SmallVector<Type> resTypes; + if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs, + resTypes, resultAttrs)) { + if (isDirect) + return parser.emitError(trailingTypesLoc, + "expected direct call to have 1 trailing types"); return parser.emitError(trailingTypesLoc, "expected trailing function type"); - if (funcType.getNumResults() > 1) + } + + if (resTypes.size() > 1) return parser.emitError(trailingTypesLoc, "expected function with 0 or 1 result"); - if (funcType.getNumResults() == 1 && - llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0))) + if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0])) return parser.emitError(trailingTypesLoc, "expected a non-void result type"); @@ -1374,12 +1388,12 @@ static ParseResult parseCallTypeAndResolveOperands( // indirect calls, while the types list is emtpy for direct calls. 
// Append the function input types to resolve the call operation // operands. - llvm::append_range(types, funcType.getInputs()); + llvm::append_range(types, argTypes); if (parser.resolveOperands(operands, types, parser.getNameLoc(), result.operands)) return failure(); - if (funcType.getNumResults() != 0) - result.addTypes(funcType.getResults()); + if (resTypes.size() != 0) + result.addTypes(resTypes); return success(); } @@ -1493,8 +1507,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the trailing type list and resolve the operands. - if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + SmallVector<DictionaryAttr> argAttrs; + SmallVector<DictionaryAttr> resultAttrs; + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, + argAttrs, resultAttrs)) return failure(); + call_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, argAttrs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, getOpBundleSizesAttrName(result.name))) @@ -1714,7 +1734,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the trailing type list and resolve the function operands. 
- if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + SmallVector<DictionaryAttr> argAttrs; + SmallVector<DictionaryAttr> resultAttrs; + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, + argAttrs, resultAttrs)) return failure(); if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 2084e527773ca8..52f42df60f0015 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -265,6 +265,27 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getWillReturnAttr()) call->addFnAttr(llvm::Attribute::WillReturn); + if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) + for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { + if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) { + FailureOr<llvm::AttrBuilder> attrBuilder = + moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); + if (resAttrsArray && resAttrsArray.size() == 1) + if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) { + FailureOr<llvm::AttrBuilder> attrBuilder = + moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { llvm::MemoryEffects memEffects = llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index eba86f06d09056..f65bf6584d51f2 100644 --- 
a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1641,6 +1641,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) callOp.setWillReturn(true); + // Handle parameter and result attributes. + convertParameterAttributes(callInst, callOp, builder); + llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); ModRefInfo othermem = convertModRefInfoFromLLVM( memEffects.getModRef(llvm::MemoryEffects::Location::Other)); @@ -2084,6 +2087,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); } +void ModuleImport::convertParameterAttributes(llvm::CallBase *call, + CallOpInterface callOp, + OpBuilder &builder) { + auto llvmAttrs = call->getAttributes(); + SmallVector<llvm::AttributeSet> llvmArgAttrsSet; + bool anyArgAttrs = false; + for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); + if (llvmArgAttrsSet.back().hasAttributes()) + anyArgAttrs = true; + } + auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) { + SmallVector<Attribute> attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? 
dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + if (anyArgAttrs) { + SmallVector<DictionaryAttr> argAttrs; + for (auto &llvmArgAttrs : llvmArgAttrsSet) + argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); + callOp.setArgAttrsAttr(getArrayAttr(argAttrs)); + } + + llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); + if (!llvmResAttr.hasAttributes()) + return; + SmallVector<DictionaryAttr, 1> resAttrs; + resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); + callOp.setResAttrsAttr(getArrayAttr(resAttrs)); +} + LogicalResult ModuleImport::processFunction(llvm::Function *func) { clearRegionState(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 4367100e3aca68..b2d2c1cddca318 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func, } } +static void convertParameterAttr(llvm::AttrBuilder &attrBuilder, + llvm::Attribute::AttrKind llvmKind, + NamedAttribute namedAttr, + ModuleTranslation &moduleTranslation) { + llvm::TypeSwitch<Attribute>(namedAttr.getValue()) + .Case<TypeAttr>([&](auto typeAttr) { + attrBuilder.addTypeAttr( + llvmKind, moduleTranslation.convertType(typeAttr.getValue())); + }) + .Case<IntegerAttr>([&](auto intAttr) { + attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); + }) + .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); }) + .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) { + attrBuilder.addConstantRangeAttr( + llvmKind, + llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper())); + }); +} + FailureOr<llvm::AttrBuilder> ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs) { @@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, auto it = 
attrNameToKindMapping.find(namedAttr.getName()); if (it != attrNameToKindMapping.end()) { llvm::Attribute::AttrKind llvmKind = it->second; - - llvm::TypeSwitch<Attribute>(namedAttr.getValue()) - .Case<TypeAttr>([&](auto typeAttr) { - attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue())); - }) - .Case<IntegerAttr>([&](auto intAttr) { - attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); - }) - .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); }) - .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) { - attrBuilder.addConstantRangeAttr( - llvmKind, llvm::ConstantRange(rangeAttr.getLower(), - rangeAttr.getUpper())); - }); + convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); } else if (namedAttr.getNameDialect()) { if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this))) return failure(); @@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +FailureOr<llvm::AttrBuilder> +ModuleTranslation::convertParameterAttrs(CallOp, int argIdx, + DictionaryAttr paramAttrs) { + llvm::AttrBuilder attrBuilder(llvmModule->getContext()); + auto attrNameToKindMapping = getAttrNameToKindMapping(); + + for (auto namedAttr : paramAttrs) { + auto it = attrNameToKindMapping.find(namedAttr.getName()); + if (it != attrNameToKindMapping.end()) { + llvm::Attribute::AttrKind llvmKind = it->second; + convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); + } + } + + return attrBuilder; +} + LogicalResult ModuleTranslation::convertFunctionSignatures() { // Declare all functions first because there may be function calls that form a // call graph with cycles, or global initializers that reference functions. 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 25806d9d0edd72..14cdcc06625c06 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { func.func private @standard_func_callee() func.func @call_missing_ptr_type(%arg : i8) { + // expected-error@+2 {{expected '('}} // expected-error@+1 {{expected direct call to have 1 trailing type}} llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8) llvm.return @@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { // ----- func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) { + // expected-error@+2 {{expected '('}} // expected-error@+1 {{expected trailing function type}} llvm.call %callee(%arg) : !llvm.ptr, !llvm.func<i8 (i8)> llvm.return diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 88660ce598f3c2..e565772f06b03c 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -941,3 +941,23 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) { llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1 llvm.return } + +llvm.func @somefunc(i32, !llvm.ptr) + +// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( +// CHECK-SAME: %[[VAL_0:.*]]: i32, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) +llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { + // CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.return +} + +// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect( +// CHECK-SAME: %[[VAL_0:.*]]: i16, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr +llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { + // 
CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + llvm.return %0 : i16 +} diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll new file mode 100644 index 00000000000000..8294579b48c63c --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll @@ -0,0 +1,25 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s + +; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr) +declare void @somefunc(i32, ptr) + +; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( +; CHECK-SAME: %[[VAL_0:.*]]: i32, +; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) +llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { +declare void @somefunc(i32, ptr) +; CHECK-LABEL: @test_call_arg_attrs_direct +define void @test_call_arg_attrs_direct(i32 %0, ptr %1) { + ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + call void @somefunc(i32 %0, ptr byval(i64) %1) + ret void +} + +; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect( +; CHECK-SAME: %[[VAL_0:.*]]: i16, +; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr +define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) { +; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %3 = tail call signext i16 %1(i16 noundef signext %0) + ret i16 %3 +} diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir new file mode 100644 index 00000000000000..89b1f29a68623b --- /dev/null +++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @somefunc(i32, !llvm.ptr) + +// CHECK-LABEL: define void @test_call_arg_attrs_direct 
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { + // CHECK: call void @somefunc(i32 %{{.*}}, ptr byval(i64) %{{.*}}) + llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.return +} + +// CHECK-LABEL: define i16 @test_call_arg_attrs_indirect +llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { + // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}}) + %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + llvm.return %0 : i16 +} >From 879b03de74daffe4f83a5c72f76fbeed495f73bf Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Thu, 16 Jan 2025 04:42:18 -0800 Subject: [PATCH 2/4] remove bogus extra lines in new test --- mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll index 8294579b48c63c..2c86ca6b03125e 100644 --- a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll @@ -6,9 +6,6 @@ declare void @somefunc(i32, ptr) ; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( ; CHECK-SAME: %[[VAL_0:.*]]: i32, ; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { -declare void @somefunc(i32, ptr) -; CHECK-LABEL: @test_call_arg_attrs_direct define void @test_call_arg_attrs_direct(i32 %0, ptr %1) { ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () call void @somefunc(i32 %0, ptr byval(i64) %1) >From 3328f681ad735cae0a5e2d06ab685bb4710fb994 Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Thu, 16 Jan 2025 08:59:05 -0800 Subject: [PATCH 3/4] rename and change inheritance level --- .../mlir/Interfaces/CallImplementation.h | 5 +--- 
.../include/mlir/Interfaces/CallInterfaces.td | 29 +++++++------------ .../mlir/Interfaces/FunctionInterfaces.td | 2 +- mlir/lib/Interfaces/CallImplementation.cpp | 2 +- mlir/lib/Transforms/Utils/InliningUtils.cpp | 24 ++++++++------- 5 files changed, 28 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Interfaces/CallImplementation.h b/mlir/include/mlir/Interfaces/CallImplementation.h index 85e47f6b3dbbb9..2edc081bddf478 100644 --- a/mlir/include/mlir/Interfaces/CallImplementation.h +++ b/mlir/include/mlir/Interfaces/CallImplementation.h @@ -20,8 +20,6 @@ namespace mlir { -class OpWithArgumentAttributesInterface; - namespace call_interface_impl { /// Parse a function or call result list. @@ -65,8 +63,7 @@ ParseResult parseFunctionSignature(OpAsmParser &parser, /// -> function-result-list /// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)* /// ssa-function-arg ::= `%`name `:` type attribute-dict? -void printFunctionSignature(OpAsmPrinter &p, - OpWithArgumentAttributesInterface op, +void printFunctionSignature(OpAsmPrinter &p, ArgumentAttributesOpInterface op, TypeRange argTypes, bool isVariadic, TypeRange resultTypes, Region *body = nullptr, bool printEmptyResult = true); diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index 80912a9762187e..1f2398387c044e 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -20,9 +20,10 @@ include "mlir/IR/OpBase.td" /// Interface for operations with arguments attributes (both call-like /// and callable operations). -def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInterface"> { +def ArgumentAttributesOpInterface : OpInterface<"ArgumentAttributesOpInterface"> { let description = [{ - A call-like or callable operation that may define attributes for its arguments. 
+ A call-like or callable operation that can hold attributes for its arguments + and results. }]; let cppNamespace = "::mlir"; let methods = [ @@ -32,40 +33,34 @@ def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInt number to the number of arguments. Alternatively, the method can return null to indicate that the region has no argument attributes. }], - "::mlir::ArrayAttr", "getArgAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, + "::mlir::ArrayAttr", "getArgAttrsAttr">, InterfaceMethod<[{ Get the array of result attribute dictionaries. The method should return an array attribute containing only dictionary attributes equal in number to the number of results. Alternatively, the method can return null to indicate that the region has no result attributes. }], - "::mlir::ArrayAttr", "getResAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, + "::mlir::ArrayAttr", "getResAttrsAttr">, InterfaceMethod<[{ Set the array of argument attribute dictionaries. }], - "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>, + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, InterfaceMethod<[{ Set the array of result attribute dictionaries. }], - "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>, + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, InterfaceMethod<[{ Remove the array of argument attribute dictionaries. This is the same as setting all argument attributes to an empty dictionary. The method should return the removed attribute. }], - "::mlir::Attribute", "removeArgAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, + "::mlir::Attribute", "removeArgAttrsAttr">, InterfaceMethod<[{ Remove the array of result attribute dictionaries. 
This is the same as setting all result attributes to an empty dictionary. The method should return the removed attribute. }], - "::mlir::Attribute", "removeResAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, + "::mlir::Attribute", "removeResAttrsAttr"> ]; } @@ -74,8 +69,7 @@ def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInt // a call-like operation. This represents the destination of the call. /// Interface for call-like operations. -def CallOpInterface : OpInterface<"CallOpInterface", - [OpWithArgumentAttributesInterface]> { +def CallOpInterface : OpInterface<"CallOpInterface"> { let description = [{ A call-like operation is one that transfers control from one sub-routine to another. These operations may be traditional direct calls `call @foo`, or @@ -138,8 +132,7 @@ def CallOpInterface : OpInterface<"CallOpInterface", } /// Interface for callable operations. -def CallableOpInterface : OpInterface<"CallableOpInterface", - [OpWithArgumentAttributesInterface]> { +def CallableOpInterface : OpInterface<"CallableOpInterface"> { let description = [{ A callable operation is one who represents a potential sub-routine, and may be a target for a call-like operation (those providing the CallOpInterface diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td index 697f951748c675..616785837e1452 100644 --- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td +++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td @@ -22,7 +22,7 @@ include "mlir/Interfaces/CallInterfaces.td" //===----------------------------------------------------------------------===// def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ - Symbol, CallableOpInterface + Symbol, CallableOpInterface, ArgumentAttributesOpInterface ]> { let cppNamespace = "::mlir"; let description = [{ diff --git a/mlir/lib/Interfaces/CallImplementation.cpp 
b/mlir/lib/Interfaces/CallImplementation.cpp index 85eca609d8dc8d..974e779e32d30b 100644 --- a/mlir/lib/Interfaces/CallImplementation.cpp +++ b/mlir/lib/Interfaces/CallImplementation.cpp @@ -95,7 +95,7 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types, } void call_interface_impl::printFunctionSignature( - OpAsmPrinter &p, OpWithArgumentAttributesInterface op, TypeRange argTypes, + OpAsmPrinter &p, ArgumentAttributesOpInterface op, TypeRange argTypes, bool isVariadic, TypeRange resultTypes, Region *body, bool printEmptyResult) { bool isExternal = !body || body->empty(); diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 0cae63c58ca7be..57a7931b56085a 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -193,11 +193,13 @@ static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, SmallVector<DictionaryAttr> argAttrs( callable.getCallableRegion()->getNumArguments(), builder.getDictionaryAttr({})); - if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) { - assert(arrayAttr.size() == argAttrs.size()); - for (auto [idx, attr] : llvm::enumerate(arrayAttr)) - argAttrs[idx] = cast<DictionaryAttr>(attr); - } + if (auto argAttrsOpInterface = + dyn_cast<ArgumentAttributesOpInterface>(callable.getOperation())) + if (ArrayAttr arrayAttr = argAttrsOpInterface.getArgAttrsAttr()) { + assert(arrayAttr.size() == argAttrs.size()); + for (auto [idx, attr] : llvm::enumerate(arrayAttr)) + argAttrs[idx] = cast<DictionaryAttr>(attr); + } // Run the argument attribute handler for the given argument and attribute. for (auto [blockArg, argAttr] : @@ -218,11 +220,13 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, // Unpack the result attributes if there are any. 
SmallVector<DictionaryAttr> resAttrs(results.size(), builder.getDictionaryAttr({})); - if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) { - assert(arrayAttr.size() == resAttrs.size()); - for (auto [idx, attr] : llvm::enumerate(arrayAttr)) - resAttrs[idx] = cast<DictionaryAttr>(attr); - } + if (auto argAttrsOpInterface = + dyn_cast<ArgumentAttributesOpInterface>(callable.getOperation())) + if (ArrayAttr arrayAttr = argAttrsOpInterface.getResAttrsAttr()) { + assert(arrayAttr.size() == resAttrs.size()); + for (auto [idx, attr] : llvm::enumerate(arrayAttr)) + resAttrs[idx] = cast<DictionaryAttr>(attr); + } // Run the result attribute handler for the given result and attribute. SmallVector<DictionaryAttr> resultAttributes; >From 854e43c16d73a3645cd224be2861836373079b1f Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Thu, 16 Jan 2025 08:59:35 -0800 Subject: [PATCH 4/4] change inheritance level of new interface --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 + mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 85f5c6cc8cca07..5c7a697107c237 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -699,6 +699,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<FastmathFlagsInterface>, DeclareOpInterfaceMethods<CallOpInterface>, + DeclareOpInterfaceMethods<ArgumentAttributesOpInterface>, DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<BranchWeightOpInterface>]> { let summary = "Call to an LLVM function."; diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index f65bf6584d51f2..f51b577c255660 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2090,6 +2090,10 @@ 
void ModuleImport::convertParameterAttributes(llvm::Function *func, void ModuleImport::convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, OpBuilder &builder) { + auto argAttrsOpInterface = + dyn_cast<ArgumentAttributesOpInterface>(callOp.getOperation()); + if (!argAttrsOpInterface) + return; auto llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; @@ -2108,7 +2112,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - callOp.setArgAttrsAttr(getArrayAttr(argAttrs)); + argAttrsOpInterface.setArgAttrsAttr(getArrayAttr(argAttrs)); } llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); @@ -2116,7 +2120,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, return; SmallVector<DictionaryAttr, 1> resAttrs; resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); - callOp.setResAttrsAttr(getArrayAttr(resAttrs)); + argAttrsOpInterface.setResAttrsAttr(getArrayAttr(resAttrs)); } LogicalResult ModuleImport::processFunction(llvm::Function *func) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits