Issue 144005
Summary [mlir][tablegen] Underlying signless integer storage for Enum Attributes is handled incorrectly
Labels mlir
Assignees
Reporter 0xMihir
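# Description
If we create an integer enum attribute with a case whose value has the most significant bit (MSB) set, the attribute fails verification: the stored value overflows the signed range when it is read back, so it no longer matches the declared case.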

The following case illustrates the bug.

First, we can define an enum in TableGen:
```tablegen
def I32Case5:  I32EnumAttrCase<"case5", 5>;
def I32Case10: I32EnumAttrCase<"case10", 10>;
def I32CaseSignedMaxPlusOne: I32EnumAttrCase<"caseSignedMaxPlusOne", 2147483648>;
def I32CaseUnsignedMax: I32EnumAttrCase<"caseUnsignedMax", 4294967295>;


def SomeI32Enum: I32EnumAttr<
  "SomeI32Enum", "", [I32Case5, I32Case10, 
                      I32CaseSignedMaxPlusOne, I32CaseUnsignedMax]>;

def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
  let arguments = (ins SomeI32Enum:$attr);
  let results = (outs I32:$val);
}
```

Then, using the defined op, we can observe that the last two cases (`2147483648` and `4294967295`) fail:

```mlir 
// CHECK-LABEL: func @allowed_cases_pass
func.func @allowed_cases_pass() {
  // CHECK: test.i32_enum_attr
  %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32
  // CHECK: test.i32_enum_attr
  %1 = "test.i32_enum_attr"() {attr = 10: i32} : () -> i32
  // CHECK: test.i32_enum_attr
  %2 = "test.i32_enum_attr"() {attr = 2147483648: i32} : () -> i32
  // CHECK: test.i32_enum_attr
  %3 = "test.i32_enum_attr"() {attr = 4294967295: i32} : () -> i32
  return
}
```


This is because both the underlying generator ([EnumsGen.cpp](https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/EnumsGen.cpp#L651)) and the TableGen code ([EnumAttr.td](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/EnumAttr.td#L43)) use the deprecated `getInt` method, which sign-extends the stored value from its most significant bit. I don't think this is the best default behavior, but I also don't think we should change this old API.
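
To make the mismatch concrete, here is a minimal sketch in plain C++ (no LLVM dependency). The helpers `signExtend32`, `zeroExtend32`, and `checkEnumCases` are hypothetical names introduced only for illustration; they imitate what `APInt::getSExtValue()` and `APInt::getZExtValue()` would return for a 32-bit payload, and are not the generated code itself.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-ins (not LLVM or MLIR API) for widening a 32-bit payload
// to 64 bits, in the spirit of APInt::getSExtValue() and getZExtValue().
constexpr int64_t signExtend32(uint32_t bits) {
  return static_cast<int64_t>(static_cast<int32_t>(bits));
}

constexpr uint64_t zeroExtend32(uint32_t bits) {
  return static_cast<uint64_t>(bits);
}

// Hypothetical checks that mirror the four enum cases from the report.
inline void checkEnumCases() {
  // Values without the MSB set agree under both widenings.
  assert(signExtend32(5u) == 5);
  assert(zeroExtend32(10u) == 10u);

  // caseSignedMaxPlusOne: sign extension yields a negative number, so a
  // comparison against 2147483648 fails; zero extension matches.
  assert(signExtend32(0x80000000u) != 2147483648LL);
  assert(zeroExtend32(0x80000000u) == 2147483648ULL);

  // caseUnsignedMax: sign extension yields -1 instead of 4294967295.
  assert(signExtend32(0xFFFFFFFFu) == -1);
  assert(zeroExtend32(0xFFFFFFFFu) == 4294967295ULL);
}
```

Under these assumptions, only the two cases with the MSB set behave differently, which is consistent with the failures observed above.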


# Fix 
To fix this, we can return the zero-extended value (`getZExtValue()`) inside EnumsGen.cpp, like so:

```diff
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 95767a29b9c3..def322a9d684 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -524,7 +524,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
 
   os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
 
-  os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+  os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getValue().getZExtValue());\n",
                 enumName);
 
   os << "}\n";
```

Similarly, in EnumAttr.td, the predicate can compare against an `APInt` of the matching bit width using `APInt`'s `eq` method.

```diff
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 9fec28f03ec2..8d004f8b7b8c 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -34,7 +34,7 @@ class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
     EnumAttrCaseInfo<sym, intVal, strVal>,
     SignlessIntegerAttrBase<intType, "case " # strVal> {
   let predicate =
-    CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>;
+    CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().eq(APInt(" # intType.bitwidth # ", " # intVal # "))">;
 }
 
 // Cases of integer enum attributes with a specific type. By default, the string
```
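
The idea behind comparing at a fixed bit width can be sketched without LLVM as follows. `lowBitsMask`, `equalAtWidth`, and `checkWidthAwareCompare` are hypothetical helpers, not part of `APInt`; the sketch simply compares the low `width` bits of both operands, so the result no longer depends on whether the stored bits are read as signed or unsigned. Unlike this sketch, which truncates silently, constructing a real `APInt` from a value that does not fit the width may behave differently.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: mask selecting the low `width` bits (width in 1..64).
constexpr uint64_t lowBitsMask(unsigned width) {
  return width >= 64 ? ~0ULL : ((1ULL << width) - 1ULL);
}

// Hypothetical helper: compares stored bits against a case value using only
// the low `width` bits of each, so no sign extension is involved.
constexpr bool equalAtWidth(uint64_t storedBits, int64_t caseValue,
                            unsigned width) {
  return (storedBits & lowBitsMask(width)) ==
         (static_cast<uint64_t>(caseValue) & lowBitsMask(width));
}

inline void checkWidthAwareCompare() {
  // Both MSB-set cases from the report match at width 32.
  assert(equalAtWidth(0x80000000u, 2147483648LL, 32));
  assert(equalAtWidth(0xFFFFFFFFu, 4294967295LL, 32));
  // Different values still compare unequal.
  assert(!equalAtWidth(5u, 10, 32));
  // The same reasoning applies at other widths, e.g. an 8-bit case of 255.
  assert(equalAtWidth(0xFFu, 255, 8));
}
```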

Happy to open a PR for the above changes and additional test cases. 