Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Update APInt construction to correctly set isSigned/implicitTrunc #110466

Merged
merged 2 commits into from
Oct 14, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Sep 30, 2024

This fixes all the places in MLIR that hit the new assertion added in #106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this patch is mostly NFC.

This is just the MLIR changes split off from #80309.

This fixes all the places in MLIR that hit the new assertion added
in llvm#106524, in preparation for enabling it by default. That is,
cases where the value passed to the APInt constructor is not an N-bit
signed/unsigned integer, where N is the bit width and signedness is
determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the
implicitTrunc flag to retain the old behavior. I've left TODOs
for the latter case in some places, where I think that it may be
worthwhile to stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so
this patch is mostly NFC.
@llvmbot
Copy link
Collaborator

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-ods

Author: Nikita Popov (nikic)

Changes

This fixes all the places in MLIR that hit the new assertion added in #106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this patch is mostly NFC.

This is just the MLIR changes split off from #80309.


Full diff: https://github.com/llvm/llvm-project/pull/110466.diff

9 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+3-1)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+2-1)
  • (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+1-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+1-1)
  • (modified) mlir/lib/IR/Builders.cpp (+12-4)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+4-2)
  • (modified) mlir/unittests/Dialect/SPIRV/SerializationTest.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index f0d41754001400..530ba7d2f11e5c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
         return $_get(type.getContext(), type, apValue);
       }
 
+      // TODO: Avoid implicit trunc?
       IntegerType intTy = ::llvm::cast<IntegerType>(type);
-      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
+      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
+                    /*implicitTrunc=*/true);
       return $_get(type.getContext(), type, apValue);
     }]>
   ];
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e2472eea8a3714..606a56c7fd55b5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -749,7 +749,8 @@ class AsmParser {
     // zero for non-negated integers.
     result =
         (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
-    if (APInt(uintResult.getBitWidth(), result) != uintResult)
+    if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
+              /*implicitTrunc=*/true) != uintResult)
       return emitError(loc, "integer value too large");
     return success();
   }
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 50e57682a2dc8d..593dbaa6c6545a 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     Type eTy = shapedTy.getElementType();
-    APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
+    APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
     return DenseIntElementsAttr::get(shapedTy, valueInt);
   }
 
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 98b429de1fd85c..edd7f607f24f4d 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..00f81db1dd795c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90bf5df67b03ba..a5f987206db11c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1073,7 +1073,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
     if (parser.parseInteger(value))
       return failure();
     shapeTmp++;
-    values.push_back(APInt(32, value));
+    values.push_back(APInt(32, value, /*isSigned=*/true));
     return success();
   };
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7aed415343e551..7a7f360c53eb2f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -234,7 +234,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
 }
 
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
-  return IntegerAttr::get(getIntegerType(32), APInt(32, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int32_t.
+  return IntegerAttr::get(getIntegerType(32),
+                          APInt(32, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
@@ -252,14 +255,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
 }
 
 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
-  return IntegerAttr::get(getIntegerType(8), APInt(8, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int8_t.
+  return IntegerAttr::get(getIntegerType(8),
+                          APInt(8, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   if (type.isIndex())
     return IntegerAttr::get(type, APInt(64, value));
-  return IntegerAttr::get(
-      type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
+  // TODO: Avoid implicit trunc?
+  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
+                                      type.isSignedInteger(),
+                                      /*implicitTrunc=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 38293f7106a05a..236e0ec100a0de 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1284,9 +1284,11 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
         uint32_t word1;
         uint32_t word2;
       } words = {operands[2], operands[3]};
-      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
+      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     } else if (bitwidth <= 32) {
-      value = APInt(bitwidth, operands[2], /*isSigned=*/true);
+      value = APInt(bitwidth, operands[2], /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     }
 
     auto attr = opBuilder.getIntegerAttr(intType, value);
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 9d2f690ed898af..ef89c1645d373f 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
       IntegerType::get(&context, 16, IntegerType::Signless);
   auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
   // Check the bit extension of same value under different signedness semantics.
-  APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
+  APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
                             signlessInt16Type.getSignedness());
   APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
                           signedInt16Type.getSignedness());

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Nikita Popov (nikic)

Changes

This fixes all the places in MLIR that hit the new assertion added in #106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this patch is mostly NFC.

This is just the MLIR changes split off from #80309.


Full diff: https://github.com/llvm/llvm-project/pull/110466.diff

9 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+3-1)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+2-1)
  • (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+1-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+1-1)
  • (modified) mlir/lib/IR/Builders.cpp (+12-4)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+4-2)
  • (modified) mlir/unittests/Dialect/SPIRV/SerializationTest.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index f0d41754001400..530ba7d2f11e5c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
         return $_get(type.getContext(), type, apValue);
       }
 
+      // TODO: Avoid implicit trunc?
       IntegerType intTy = ::llvm::cast<IntegerType>(type);
-      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
+      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
+                    /*implicitTrunc=*/true);
       return $_get(type.getContext(), type, apValue);
     }]>
   ];
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e2472eea8a3714..606a56c7fd55b5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -749,7 +749,8 @@ class AsmParser {
     // zero for non-negated integers.
     result =
         (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
-    if (APInt(uintResult.getBitWidth(), result) != uintResult)
+    if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
+              /*implicitTrunc=*/true) != uintResult)
       return emitError(loc, "integer value too large");
     return success();
   }
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 50e57682a2dc8d..593dbaa6c6545a 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     Type eTy = shapedTy.getElementType();
-    APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
+    APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
     return DenseIntElementsAttr::get(shapedTy, valueInt);
   }
 
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 98b429de1fd85c..edd7f607f24f4d 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..00f81db1dd795c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90bf5df67b03ba..a5f987206db11c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1073,7 +1073,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
     if (parser.parseInteger(value))
       return failure();
     shapeTmp++;
-    values.push_back(APInt(32, value));
+    values.push_back(APInt(32, value, /*isSigned=*/true));
     return success();
   };
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7aed415343e551..7a7f360c53eb2f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -234,7 +234,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
 }
 
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
-  return IntegerAttr::get(getIntegerType(32), APInt(32, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int32_t.
+  return IntegerAttr::get(getIntegerType(32),
+                          APInt(32, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
@@ -252,14 +255,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
 }
 
 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
-  return IntegerAttr::get(getIntegerType(8), APInt(8, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int8_t.
+  return IntegerAttr::get(getIntegerType(8),
+                          APInt(8, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   if (type.isIndex())
     return IntegerAttr::get(type, APInt(64, value));
-  return IntegerAttr::get(
-      type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
+  // TODO: Avoid implicit trunc?
+  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
+                                      type.isSignedInteger(),
+                                      /*implicitTrunc=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 38293f7106a05a..236e0ec100a0de 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1284,9 +1284,11 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
         uint32_t word1;
         uint32_t word2;
       } words = {operands[2], operands[3]};
-      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
+      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     } else if (bitwidth <= 32) {
-      value = APInt(bitwidth, operands[2], /*isSigned=*/true);
+      value = APInt(bitwidth, operands[2], /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     }
 
     auto attr = opBuilder.getIntegerAttr(intType, value);
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 9d2f690ed898af..ef89c1645d373f 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
       IntegerType::get(&context, 16, IntegerType::Signless);
   auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
   // Check the bit extension of same value under different signedness semantics.
-  APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
+  APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
                             signlessInt16Type.getSignedness());
   APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
                           signedInt16Type.getSignedness());

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Nikita Popov (nikic)

Changes

This fixes all the places in MLIR that hit the new assertion added in #106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this patch is mostly NFC.

This is just the MLIR changes split off from #80309.


Full diff: https://github.com/llvm/llvm-project/pull/110466.diff

9 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+3-1)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+2-1)
  • (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+1-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+1-1)
  • (modified) mlir/lib/IR/Builders.cpp (+12-4)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+4-2)
  • (modified) mlir/unittests/Dialect/SPIRV/SerializationTest.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index f0d41754001400..530ba7d2f11e5c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
         return $_get(type.getContext(), type, apValue);
       }
 
+      // TODO: Avoid implicit trunc?
       IntegerType intTy = ::llvm::cast<IntegerType>(type);
-      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
+      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
+                    /*implicitTrunc=*/true);
       return $_get(type.getContext(), type, apValue);
     }]>
   ];
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e2472eea8a3714..606a56c7fd55b5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -749,7 +749,8 @@ class AsmParser {
     // zero for non-negated integers.
     result =
         (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
-    if (APInt(uintResult.getBitWidth(), result) != uintResult)
+    if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
+              /*implicitTrunc=*/true) != uintResult)
       return emitError(loc, "integer value too large");
     return success();
   }
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 50e57682a2dc8d..593dbaa6c6545a 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     Type eTy = shapedTy.getElementType();
-    APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
+    APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
     return DenseIntElementsAttr::get(shapedTy, valueInt);
   }
 
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 98b429de1fd85c..edd7f607f24f4d 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..00f81db1dd795c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90bf5df67b03ba..a5f987206db11c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1073,7 +1073,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
     if (parser.parseInteger(value))
       return failure();
     shapeTmp++;
-    values.push_back(APInt(32, value));
+    values.push_back(APInt(32, value, /*isSigned=*/true));
     return success();
   };
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7aed415343e551..7a7f360c53eb2f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -234,7 +234,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
 }
 
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
-  return IntegerAttr::get(getIntegerType(32), APInt(32, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int32_t.
+  return IntegerAttr::get(getIntegerType(32),
+                          APInt(32, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
@@ -252,14 +255,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
 }
 
 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
-  return IntegerAttr::get(getIntegerType(8), APInt(8, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int8_t.
+  return IntegerAttr::get(getIntegerType(8),
+                          APInt(8, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   if (type.isIndex())
     return IntegerAttr::get(type, APInt(64, value));
-  return IntegerAttr::get(
-      type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
+  // TODO: Avoid implicit trunc?
+  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
+                                      type.isSignedInteger(),
+                                      /*implicitTrunc=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 38293f7106a05a..236e0ec100a0de 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1284,9 +1284,11 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
         uint32_t word1;
         uint32_t word2;
       } words = {operands[2], operands[3]};
-      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
+      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     } else if (bitwidth <= 32) {
-      value = APInt(bitwidth, operands[2], /*isSigned=*/true);
+      value = APInt(bitwidth, operands[2], /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     }
 
     auto attr = opBuilder.getIntegerAttr(intType, value);
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 9d2f690ed898af..ef89c1645d373f 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
       IntegerType::get(&context, 16, IntegerType::Signless);
   auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
   // Check the bit extension of same value under different signedness semantics.
-  APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
+  APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
                             signlessInt16Type.getSignedness());
   APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
                           signedInt16Type.getSignedness());

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPIR-V changes LGTM, I didn't check the other code

@nikic
Copy link
Contributor Author

nikic commented Oct 7, 2024

Ping for the non-SPIRV parts of this PR.

@joker-eph joker-eph changed the title [MLIR] Make compatible with APInt ctor assertion [MLIR] Update APInt construction to correctly set isSigned/implicitTrunc Oct 7, 2024
IntegerType intTy = ::llvm::cast<IntegerType>(type);
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
/*implicitTrunc=*/true);
Copy link
Collaborator

@joker-eph joker-eph Oct 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems strange we hit an issue here since we're passing the expected isSigned here, that would be a bug right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe I am missing how the assertion operates and when is implicitTrunc legit to use?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, places marked with TODO probably shouldn't be using implicitTrunc. Here are the failures if we drop it: https://gist.github.com/nikic/d69e30cf1d28ef5988363dc11e203159

I'm guessing the main problem is

loc, n_type, IntegerAttr::get(n_type, -1));
trying to construct -1 of either an unsigned or signless integer type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the right way to avoid an implicit truncation with signless integers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not familiar with MLIR, but probably by using a ctor that accepts APInt? That way you can explicitly specify that the constant needs to be sign extended.

Or possibly the code here should be setting signed=true for signless integers, as parameter is int64_t so signed value should be the default assumption?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "proper" solution here is probably to either a) Only allow constructing signless integers from APInt or b) only allow constructing them from plain integers for bit widths <= 64 bit, as the sign distinction only becomes really problematic for larger bit widths.

But in any case, this is not something I want to touch in this PR...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you merged without further approvals or resolving completely this thread here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I took this as a discussion on how the TODO could be resolved in the future, not a blocking concern for this PR. Was there something you wanted me to change in this PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// TODO: Avoid implicit trunc?

This isn't super clear or inciting to a fix. The TODO is a question? Such TODO looks to me like a PR "not ready to merge" because the actual "TODO" needs to be figured.

I would have tried encoded more of the discussion in the TODO actually, to make it more accessible for anyone seeing this TODO in terms of investigating what to do to fix it and not have to rediscover the investigation you done here.
I would think that a better TODO would look something like:

// TODO: We shouldn't use implicit trunc here, at the moment however treating signless integer creation .....

I'm actually not even sure how to finish the sentence, you didn't expand enough on the problem you saw with signless integer ("Tried that, and it causes other failures.").

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I agree that the TODO is not super clear without some context. What I'd like to do is to add a reference to #112510 (which I just filed) to these TODOs, so there is a place that provides a more detailed explanation of the purpose of the TODO and how it may be resolved.

I don't want to add detailed analysis to individual TODO comments, as the point of leaving them is so that I don't have to analyze each one in detail. (For the record, I'm trying to get this assertion enabled for more than half a year already, and not tracking down everything to its leafs is part of the compromise to make this feasible at all.)

@nikic nikic merged commit e692af8 into llvm:main Oct 14, 2024
8 checks passed
@nikic nikic deleted the mlir-apint-assert branch October 14, 2024 13:01
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
…unc (llvm#110466)

This fixes all the places in MLIR that hit the new assertion added in
llvm#106524, in preparation for enabling it by default. That is, cases where
the value passed to the APInt constructor is not an N-bit
signed/unsigned integer, where N is the bit width and signedness is
determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the
implicitTrunc flag to retain the old behavior. I've left TODOs for the
latter case in some places, where I think that it may be worthwhile to
stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this
patch is mostly NFC.

This is just the MLIR changes split off from
llvm#80309.
bricknerb pushed a commit to bricknerb/llvm-project that referenced this pull request Oct 17, 2024
…unc (llvm#110466)

This fixes all the places in MLIR that hit the new assertion added in
llvm#106524, in preparation for enabling it by default. That is, cases where
the value passed to the APInt constructor is not an N-bit
signed/unsigned integer, where N is the bit width and signedness is
determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the
implicitTrunc flag to retain the old behavior. I've left TODOs for the
latter case in some places, where I think that it may be worthwhile to
stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this
patch is mostly NFC.

This is just the MLIR changes split off from
llvm#80309.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants