https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/123177
>From 137705661c184ea1530982c19163341933ab421e Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Wed, 15 Jan 2025 09:09:53 -0800 Subject: [PATCH 1/4] [mlir][LLVM] add argument and result attributes to llvm.call --- llvm/include/llvm/IR/InstrTypes.h | 11 +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 4 +- .../include/mlir/Target/LLVMIR/ModuleImport.h | 8 ++- .../mlir/Target/LLVMIR/ModuleTranslation.h | 9 ++- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 67 +++++++++++++------ .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 21 ++++++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 35 ++++++++++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 52 ++++++++++---- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 + mlir/test/Dialect/LLVMIR/roundtrip.mlir | 20 ++++++ .../LLVMIR/Import/call-argument-attributes.ll | 25 +++++++ .../LLVMIR/call-argument-attributes.mlir | 17 +++++ 12 files changed, 230 insertions(+), 41 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll create mode 100644 mlir/test/Target/LLVMIR/call-argument-attributes.mlir diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index b8d9cc10292f4a..0e391325eebdce 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1490,6 +1490,11 @@ class CallBase : public Instruction { Attrs = Attrs.addRetAttribute(getContext(), Attr); } + /// Adds attributes to the return value. 
+ void addRetAttrs(const AttrBuilder &B) { + Attrs = Attrs.addRetAttributes(getContext(), B); + } + /// Adds the attribute to the indicated argument void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) { assert(ArgNo < arg_size() && "Out of bounds"); @@ -1502,6 +1507,12 @@ class CallBase : public Instruction { Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr); } + /// Adds attributes to the indicated argument + void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) { + assert(ArgNo < arg_size() && "Out of bounds"); + Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B); + } + /// removes the attribute from the list of attributes. void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) { Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b2281536aa40b6..85f5c6cc8cca07 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -755,7 +755,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", VariadicOfVariadic<LLVM_Type, "op_bundle_sizes">:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr<ArrayAttr>:$op_bundle_tags); + OptionalAttr<ArrayAttr>:$op_bundle_tags, + OptionalAttr<DictArrayAttr>:$arg_attrs, + OptionalAttr<DictArrayAttr>:$res_attrs); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. 
let arguments = !con(args, aliasAttrs); let results = (outs Optional<LLVM_Type>:$result); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 33c9af7c6335a4..86e1d6a04cd096 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -326,14 +326,18 @@ class ModuleImport { SmallVectorImpl<Type> &types, SmallVectorImpl<Value> &operands, bool allowInlineAsm = false); - /// Converts the parameter attributes attached to `func` and adds them to the - /// `funcOp`. + /// Converts the parameter and result attributes attached to `func` and adds + /// them to the `funcOp`. void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, OpBuilder &builder); /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding /// DictionaryAttr for the LLVM dialect. DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, OpBuilder &builder); + /// Converts the parameter and result attributes attached to `call` and adds + /// them to the `callOp`. + void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, + OpBuilder &builder); /// Returns the builtin type equivalent to the given LLVM dialect type or /// nullptr if there is no equivalent. The returned type can be used to create /// an attribute for a GlobalOp or a ConstantOp. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 1b62437761ed9d..88fc17ca4fda24 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -228,6 +228,11 @@ class ModuleTranslation { /*recordInsertions=*/false); } + /// Translates parameter attributes of a call and adds them to the returned + /// AttrBuilder. Returns failure if any of the translations failed. 
+ FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx, + DictionaryAttr paramAttrs); + /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name); @@ -346,8 +351,8 @@ class ModuleTranslation { convertDialectAttributes(Operation *op, ArrayRef<llvm::Instruction *> instructions); - /// Translates parameter attributes and adds them to the returned AttrBuilder. - /// Returns failure if any of the translations failed. + /// Translates parameter attributes of a function and adds them to the + /// returned AttrBuilder. Returns failure if any of the translations failed. FailureOr<llvm::AttrBuilder> convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index ef1e0222e05f06..6c4988bac7813e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, 
/*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1331,42 +1335,52 @@ void CallOp::print(OpAsmPrinter &p) { getVarCalleeTypeAttrName(), getCConvAttrName(), getOperandSegmentSizesAttrName(), getOpBundleSizesAttrName(), - getOpBundleTagsAttrName()}); + getOpBundleTagsAttrName(), getArgAttrsAttrName(), + getResAttrsAttrName()}); p << " : "; if (!isDirect) p << getOperand(0).getType() << ", "; // Reconstruct the function MLIR function type from operand and result types. - p.printFunctionalType(args.getTypes(), getResultTypes()); + call_interface_impl::printFunctionSignature( + p, *this, args.getTypes(), /*isVariadic=*/false, getResultTypes()); } /// Parses the type of a call operation and resolves the operands if the parsing /// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands( OpAsmParser &parser, OperationState &result, bool isDirect, - ArrayRef<OpAsmParser::UnresolvedOperand> operands) { + ArrayRef<OpAsmParser::UnresolvedOperand> operands, + SmallVectorImpl<DictionaryAttr> &argAttrs, + SmallVectorImpl<DictionaryAttr> &resultAttrs) { SMLoc trailingTypesLoc = parser.getCurrentLocation(); SmallVector<Type> types; - if (parser.parseColonTypeList(types)) + if (parser.parseColon()) return failure(); - - if (isDirect && types.size() != 1) - return parser.emitError(trailingTypesLoc, - "expected direct call to have 1 trailing type"); - if (!isDirect && types.size() != 2) - return parser.emitError(trailingTypesLoc, - "expected indirect call to have 2 trailing types"); - - auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val()); - if (!funcType) + if (!isDirect) { + types.emplace_back(); + if (parser.parseType(types.back())) + return failure(); + if (parser.parseOptionalComma()) + return parser.emitError( + trailingTypesLoc, "expected indirect call to have 2 trailing types"); + } + SmallVector<Type> argTypes; + SmallVector<Type> resTypes; + if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs, + resTypes, resultAttrs)) { + if (isDirect) + return parser.emitError(trailingTypesLoc, + "expected direct call to have 1 trailing type"); return parser.emitError(trailingTypesLoc, "expected trailing function type"); - if (funcType.getNumResults() > 1) + } + + if (resTypes.size() > 1) return parser.emitError(trailingTypesLoc, "expected function with 0 or 1 result"); - if (funcType.getNumResults() == 1 && - llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0))) + if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0])) return parser.emitError(trailingTypesLoc, "expected a non-void result type"); @@ -1374,12 +1388,12 @@ static ParseResult parseCallTypeAndResolveOperands(
// Append the function input types to resolve the call operation // operands. - llvm::append_range(types, funcType.getInputs()); + llvm::append_range(types, argTypes); if (parser.resolveOperands(operands, types, parser.getNameLoc(), result.operands)) return failure(); - if (funcType.getNumResults() != 0) - result.addTypes(funcType.getResults()); + if (resTypes.size() != 0) + result.addTypes(resTypes); return success(); } @@ -1493,8 +1507,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the trailing type list and resolve the operands. - if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + SmallVector<DictionaryAttr> argAttrs; + SmallVector<DictionaryAttr> resultAttrs; + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, + argAttrs, resultAttrs)) return failure(); + call_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, argAttrs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, getOpBundleSizesAttrName(result.name))) @@ -1714,7 +1734,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the trailing type list and resolve the function operands. 
- if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) + SmallVector<DictionaryAttr> argAttrs; + SmallVector<DictionaryAttr> resultAttrs; + if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, + argAttrs, resultAttrs)) return failure(); if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, opBundleOperandTypes, diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 2084e527773ca8..52f42df60f0015 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -265,6 +265,27 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getWillReturnAttr()) call->addFnAttr(llvm::Attribute::WillReturn); + if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) + for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { + if (auto argAttrs = llvm::dyn_cast<DictionaryAttr>(argAttrsAttr)) { + FailureOr<llvm::AttrBuilder> attrBuilder = + moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + ArrayAttr resAttrsArray = callOp.getResAttrsAttr(); + if (resAttrsArray && resAttrsArray.size() == 1) + if (auto resAttrs = llvm::dyn_cast<DictionaryAttr>(resAttrsArray[0])) { + FailureOr<llvm::AttrBuilder> attrBuilder = + moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { llvm::MemoryEffects memEffects = llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index eba86f06d09056..f65bf6584d51f2 100644 ---
a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1641,6 +1641,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) callOp.setWillReturn(true); + // Handle parameter and result attributes. + convertParameterAttributes(callInst, callOp, builder); + llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); ModRefInfo othermem = convertModRefInfoFromLLVM( memEffects.getModRef(llvm::MemoryEffects::Location::Other)); @@ -2084,6 +2087,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); } +void ModuleImport::convertParameterAttributes(llvm::CallBase *call, + CallOpInterface callOp, + OpBuilder &builder) { + auto llvmAttrs = call->getAttributes(); + SmallVector<llvm::AttributeSet> llvmArgAttrsSet; + bool anyArgAttrs = false; + for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); + if (llvmArgAttrsSet.back().hasAttributes()) + anyArgAttrs = true; + } + auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) { + SmallVector<Attribute> attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? 
dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + if (anyArgAttrs) { + SmallVector<DictionaryAttr> argAttrs; + for (auto &llvmArgAttrs : llvmArgAttrsSet) + argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); + callOp.setArgAttrsAttr(getArrayAttr(argAttrs)); + } + + llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); + if (!llvmResAttr.hasAttributes()) + return; + SmallVector<DictionaryAttr, 1> resAttrs; + resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); + callOp.setResAttrsAttr(getArrayAttr(resAttrs)); +} + LogicalResult ModuleImport::processFunction(llvm::Function *func) { clearRegionState(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 4367100e3aca68..b2d2c1cddca318 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func, } } +static void convertParameterAttr(llvm::AttrBuilder &attrBuilder, + llvm::Attribute::AttrKind llvmKind, + NamedAttribute namedAttr, + ModuleTranslation &moduleTranslation) { + llvm::TypeSwitch<Attribute>(namedAttr.getValue()) + .Case<TypeAttr>([&](auto typeAttr) { + attrBuilder.addTypeAttr( + llvmKind, moduleTranslation.convertType(typeAttr.getValue())); + }) + .Case<IntegerAttr>([&](auto intAttr) { + attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); + }) + .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); }) + .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) { + attrBuilder.addConstantRangeAttr( + llvmKind, + llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper())); + }); +} + FailureOr<llvm::AttrBuilder> ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs) { @@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, auto it = 
attrNameToKindMapping.find(namedAttr.getName()); if (it != attrNameToKindMapping.end()) { llvm::Attribute::AttrKind llvmKind = it->second; - - llvm::TypeSwitch<Attribute>(namedAttr.getValue()) - .Case<TypeAttr>([&](auto typeAttr) { - attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue())); - }) - .Case<IntegerAttr>([&](auto intAttr) { - attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); - }) - .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); }) - .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) { - attrBuilder.addConstantRangeAttr( - llvmKind, llvm::ConstantRange(rangeAttr.getLower(), - rangeAttr.getUpper())); - }); + convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); } else if (namedAttr.getNameDialect()) { if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this))) return failure(); @@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +FailureOr<llvm::AttrBuilder> +ModuleTranslation::convertParameterAttrs(CallOp, int argIdx, + DictionaryAttr paramAttrs) { + llvm::AttrBuilder attrBuilder(llvmModule->getContext()); + auto attrNameToKindMapping = getAttrNameToKindMapping(); + + for (auto namedAttr : paramAttrs) { + auto it = attrNameToKindMapping.find(namedAttr.getName()); + if (it != attrNameToKindMapping.end()) { + llvm::Attribute::AttrKind llvmKind = it->second; + convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this); + } + } + + return attrBuilder; +} + LogicalResult ModuleTranslation::convertFunctionSignatures() { // Declare all functions first because there may be function calls that form a // call graph with cycles, or global initializers that reference functions. 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 25806d9d0edd72..14cdcc06625c06 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { func.func private @standard_func_callee() func.func @call_missing_ptr_type(%arg : i8) { + // expected-error@+2 {{expected '('}} // expected-error@+1 {{expected direct call to have 1 trailing type}} llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8) llvm.return @@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) { // ----- func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) { + // expected-error@+2 {{expected '('}} // expected-error@+1 {{expected trailing function type}} llvm.call %callee(%arg) : !llvm.ptr, !llvm.func<i8 (i8)> llvm.return diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 88660ce598f3c2..e565772f06b03c 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -941,3 +941,23 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) { llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1 llvm.return } + +llvm.func @somefunc(i32, !llvm.ptr) + +// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( +// CHECK-SAME: %[[VAL_0:.*]]: i32, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) +llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { + // CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.return +} + +// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect( +// CHECK-SAME: %[[VAL_0:.*]]: i16, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr +llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { + // 
CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + llvm.return %0 : i16 +} diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll new file mode 100644 index 00000000000000..8294579b48c63c --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll @@ -0,0 +1,25 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s + +; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr) +declare void @somefunc(i32, ptr) + +; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( +; CHECK-SAME: %[[VAL_0:.*]]: i32, +; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) +llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { +declare void @somefunc(i32, ptr) +; CHECK-LABEL: @test_call_arg_attrs_direct +define void @test_call_arg_attrs_direct(i32 %0, ptr %1) { + ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + call void @somefunc(i32 %0, ptr byval(i64) %1) + ret void +} + +; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect( +; CHECK-SAME: %[[VAL_0:.*]]: i16, +; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr +define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) { +; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + %3 = tail call signext i16 %1(i16 noundef signext %0) + ret i16 %3 +} diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir new file mode 100644 index 00000000000000..89b1f29a68623b --- /dev/null +++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @somefunc(i32, !llvm.ptr) + +// CHECK-LABEL: define void @test_call_arg_attrs_direct 
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { + // CHECK: call void @somefunc(i32 %{{.*}}, ptr byval(i64) %{{.*}}) + llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> () + llvm.return +} + +// CHECK-LABEL: define i16 @test_call_arg_attrs_indirect +llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 { + // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}}) + %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) + llvm.return %0 : i16 +} >From 879b03de74daffe4f83a5c72f76fbeed495f73bf Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Thu, 16 Jan 2025 04:42:18 -0800 Subject: [PATCH 2/4] remove bogus extra lines in new test --- mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll index 8294579b48c63c..2c86ca6b03125e 100644 --- a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll @@ -6,9 +6,6 @@ declare void @somefunc(i32, ptr) ; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct( ; CHECK-SAME: %[[VAL_0:.*]]: i32, ; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) { -declare void @somefunc(i32, ptr) -; CHECK-LABEL: @test_call_arg_attrs_direct define void @test_call_arg_attrs_direct(i32 %0, ptr %1) { ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> () call void @somefunc(i32 %0, ptr byval(i64) %1) >From 854e43c16d73a3645cd224be2861836373079b1f Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Thu, 16 Jan 2025 08:59:35 -0800 Subject: [PATCH 3/4] change inheritance level of new interface --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 +
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 85f5c6cc8cca07..5c7a697107c237 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -699,6 +699,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<FastmathFlagsInterface>, DeclareOpInterfaceMethods<CallOpInterface>, + DeclareOpInterfaceMethods<ArgumentAttributesOpInterface>, DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<BranchWeightOpInterface>]> { let summary = "Call to an LLVM function."; diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index f65bf6584d51f2..f51b577c255660 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2090,6 +2090,10 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, void ModuleImport::convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, OpBuilder &builder) { + auto argAttrsOpInterface = + dyn_cast<ArgumentAttributesOpInterface>(callOp.getOperation()); + if (!argAttrsOpInterface) + return; auto llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; @@ -2108,7 +2112,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - callOp.setArgAttrsAttr(getArrayAttr(argAttrs)); + argAttrsOpInterface.setArgAttrsAttr(getArrayAttr(argAttrs)); } llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); @@ -2116,7 +2120,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, return; SmallVector<DictionaryAttr, 1> resAttrs; 
resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); - callOp.setResAttrsAttr(getArrayAttr(resAttrs)); + argAttrsOpInterface.setResAttrsAttr(getArrayAttr(resAttrs)); } LogicalResult ModuleImport::processFunction(llvm::Function *func) { >From b588ff3dcdcd85414ce0ed6b274cce2ee3db2bd5 Mon Sep 17 00:00:00 2001 From: Jean Perier <jper...@nvidia.com> Date: Tue, 21 Jan 2025 02:02:00 -0800 Subject: [PATCH 4/4] adapt to ArgumentAttributesOpInterface iface removal --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 - mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 ++- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++------ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f6721819d04e6f..ee6e10efed4f16 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -701,7 +701,6 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<FastmathFlagsInterface>, DeclareOpInterfaceMethods<CallOpInterface>, - DeclareOpInterfaceMethods<ArgumentAttributesOpInterface>, DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<BranchWeightOpInterface>]> { let summary = "Call to an LLVM function."; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index a1a14f41e122b5..c10abdc24527e4 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1344,7 +1344,8 @@ void CallOp::print(OpAsmPrinter &p) { // Reconstruct the function MLIR function type from operand and result types. 
call_interface_impl::printFunctionSignature( - p, *this, args.getTypes(), /*isVariadic=*/false, getResultTypes()); + p, args.getTypes(), getArgAttrsAttr(), + /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); } /// Parses the type of a call operation and resolves the operands if the parsing diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index f51b577c255660..f65bf6584d51f2 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2090,10 +2090,6 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, void ModuleImport::convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, OpBuilder &builder) { - auto argAttrsOpInterface = - dyn_cast<ArgumentAttributesOpInterface>(callOp.getOperation()); - if (!argAttrsOpInterface) - return; auto llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; @@ -2112,7 +2108,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - argAttrsOpInterface.setArgAttrsAttr(getArrayAttr(argAttrs)); + callOp.setArgAttrsAttr(getArrayAttr(argAttrs)); } llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); @@ -2120,7 +2116,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, return; SmallVector<DictionaryAttr, 1> resAttrs; resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder)); - argAttrsOpInterface.setResAttrsAttr(getArrayAttr(resAttrs)); + callOp.setResAttrsAttr(getArrayAttr(resAttrs)); } LogicalResult ModuleImport::processFunction(llvm::Function *func) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits