Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Update APInt construction to correctly set isSigned/implicitTrunc #110466

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
return $_get(type.getContext(), type, apValue);
}

// TODO: Avoid implicit trunc?
IntegerType intTy = ::llvm::cast<IntegerType>(type);
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
/*implicitTrunc=*/true);
Copy link
Collaborator

@joker-eph joker-eph Oct 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems strange we hit an issue here since we're passing the expected isSigned here, that would be a bug right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe I am missing how the assertion operates and when is implicitTrunc legit to use?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, places marked with TODO probably shouldn't be using implicitTrunc. Here are the failures if we drop it: https://gist.github.com/nikic/d69e30cf1d28ef5988363dc11e203159

I'm guessing the main problem is

loc, n_type, IntegerAttr::get(n_type, -1));
trying to construct -1 of either an unsigned or signless integer type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the right way to avoid an implicit truncation with signless integers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not familiar with MLIR, but probably by using a ctor that accepts APInt? That way you can explicitly specify that the constant needs to be sign extended.

Or possibly the code here should be setting signed=true for signless integers, as parameter is int64_t so signed value should be the default assumption?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "proper" solution here is probably to either a) Only allow constructing signless integers from APInt or b) only allow constructing them from plain integers for bit widths <= 64 bit, as the sign distinction only becomes really problematic for larger bit widths.

But in any case, this is not something I want to touch in this PR...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you merged without further approvals or resolving completely this thread here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I took this as a discussion on how the TODO could be resolved in the future, not a blocking concern for this PR. Was there something you wanted me to change in this PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// TODO: Avoid implicit trunc?

This isn't super clear or inciting to a fix. The TODO is a question? Such TODO looks to me like a PR "not ready to merge" because the actual "TODO" needs to be figured.

I would have tried encoded more of the discussion in the TODO actually, to make it more accessible for anyone seeing this TODO in terms of investigating what to do to fix it and not have to rediscover the investigation you done here.
I would think that a better TODO would look something like:

// TODO: We shouldn't use implicit trunc here, at the moment however treating signless integer creation .....

I'm actually not even sure how to finish the sentence, you didn't expand enough on the problem you saw with signless integer ("Tried that, and it causes other failures.").

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I agree that the TODO is not super clear without some context. What I'd like to do is to add a reference to #112510 (which I just filed) to these TODOs, so there is a place that provides a more detailed explanation of the purpose of the TODO and how it may be resolved.

I don't want to add detailed analysis to individual TODO comments, as the point of leaving them is so that I don't have to analyze each one in detail. (For the record, I'm trying to get this assertion enabled for more than half a year already, and not tracking down everything to its leafs is part of the compromise to make this feasible at all.)

return $_get(type.getContext(), type, apValue);
}]>
];
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,8 @@ class AsmParser {
// zero for non-negated integers.
result =
(IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
if (APInt(uintResult.getBitWidth(), result) != uintResult)
if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
/*implicitTrunc=*/true) != uintResult)
return emitError(loc, "integer value too large");
return success();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
Type eTy = shapedTy.getElementType();
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
return DenseIntElementsAttr::get(shapedTy, valueInt);
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value));
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));

Block *destination;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value));
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));

Block *destination;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
if (parser.parseInteger(value))
return failure();
shapeTmp++;
values.push_back(APInt(32, value));
values.push_back(APInt(32, value, /*isSigned=*/true));
return success();
};

Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
}

IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
// The APInt always uses isSigned=true here because we accept the value
// as int32_t.
return IntegerAttr::get(getIntegerType(32),
APInt(32, value, /*isSigned=*/true));
}

IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
Expand All @@ -252,14 +255,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
}

IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
return IntegerAttr::get(getIntegerType(8), APInt(8, value));
// The APInt always uses isSigned=true here because we accept the value
// as int8_t.
return IntegerAttr::get(getIntegerType(8),
APInt(8, value, /*isSigned=*/true));
}

IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
if (type.isIndex())
return IntegerAttr::get(type, APInt(64, value));
return IntegerAttr::get(
type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
// TODO: Avoid implicit trunc?
return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
type.isSignedInteger(),
/*implicitTrunc=*/true));
}

IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,8 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
} words = {operands[2], operands[3]};
value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
} else if (bitwidth <= 32) {
value = APInt(bitwidth, operands[2], /*isSigned=*/true);
value = APInt(bitwidth, operands[2], /*isSigned=*/true,
/*implicitTrunc=*/true);
}

auto attr = opBuilder.getIntegerAttr(intType, value);
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
IntegerType::get(&context, 16, IntegerType::Signless);
auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
// Check the bit extension of same value under different signedness semantics.
APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
signlessInt16Type.getSignedness());
APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
signedInt16Type.getSignedness());
Expand Down
Loading