-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][vector] Add alignment attribute to vector operations. #152507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add alignment attribute to vector operations. #152507
Conversation
5a47fff
to
e2ad0f9
Compare
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Erick Ochoa Lopez (amd-eochoalo) ChangesFollowing #144344, #152207, #151690, this PR adds the alignment attribute to the following operations in the vector dialect:
Full diff: https://github.com/llvm/llvm-project/pull/152507.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b3b8afdd8b4c1..aae2051600251 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2054,7 +2054,9 @@ def Vector_GatherOp :
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$pass_thru)>,
+ AnyVectorOfNonZeroRank:$pass_thru,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = [{
@@ -2111,6 +2113,31 @@ def Vector_GatherOp :
"`into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$index_vec,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>,
+ OpBuilder<(ins "TypeRange":$resultTypes,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$index_vec,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultTypes, base, indices, index_vec, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_ScatterOp :
@@ -2119,7 +2146,9 @@ def Vector_ScatterOp :
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore)> {
+ AnyVectorOfNonZeroRank:$valueToStore,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
let summary = [{
scatters elements from a vector into memory as defined by an index vector
@@ -2177,6 +2206,19 @@ def Vector_ScatterOp :
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "Value":$base,
+ "ValueRange":$indices,
+ "Value":$index_vec,
+ "Value":$mask,
+ "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
+ return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_ExpandLoadOp :
@@ -2184,7 +2226,9 @@ def Vector_ExpandLoadOp :
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$pass_thru)>,
+ AnyVectorOfNonZeroRank:$pass_thru,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2246,6 +2290,29 @@ def Vector_ExpandLoadOp :
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultType, base, indices, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>,
+ OpBuilder<(ins "TypeRange":$resultTypes,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultTypes, base, indices, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_CompressStoreOp :
@@ -2253,7 +2320,9 @@ def Vector_CompressStoreOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore)> {
+ AnyVectorOfNonZeroRank:$valueToStore,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
let summary = "writes elements selectively from a vector as defined by a mask";
@@ -2312,6 +2381,17 @@ def Vector_CompressStoreOp :
"type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, base, indices, valueToStore, mask,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_ShapeCastOp :
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 211e16db85a94..68b07ec82aeb7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1470,6 +1470,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
// -----
+func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+ // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = -1 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+ // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1531,6 +1549,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
// -----
+func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.scatter %base[%c0][%indices], %mask, %value { alignment = -1 }
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
@@ -1571,6 +1607,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
// -----
+func.func @expand_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @expand_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
@@ -1603,6 +1653,20 @@ func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>
// -----
+func.func @compress_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.compressstore %base[%c0], %mask, %value { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @compress_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.compressstore %base[%c0], %mask, %value { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
func.func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
// expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
%0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update op documentation and describe the semantics?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@@ -1919,7 +1922,6 @@ def Vector_MaskedLoadOp : | |||
load operation. It must be a positive power of 2. The operation must access | |||
memory at an address aligned to this boundary. Violations may lead to | |||
architecture-specific faults or performance penalties. | |||
A value of 0 indicates no specific alignment requirement. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remind me what happens when alignment is not specified?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I originally wanted thought about removing this line since I imagined that the constructors using the llvm::Maybe
align will be preferred, but I now believe that adding this line back makes more sense since there are other constructors as well and the actual value stored is an integer attribute. Thanks for pointing it out! 47db5b1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And just to double check - is 0
the default value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0
is not the default value of the alignment
attribute in the operation. I think there's some ambiguity here regarding what "default" means. It could mean two things in my opinion:
- The default parameter to one of the constructors.
- The default value in the Operation's alignment field.
In the original PR (#144344) the default parameter to one of the constructors is indeed zero, but the attribute is optional and the attribute linked to the operation is actually a nullptr
.
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, memref, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
In PR #151690 the default parameter for these constructors was changed from a uint64_t
type to an llvm::MaybeAlign
, but keeps the attribute linked to the operation to be nullptr
when there is no alignment requirement. I.e., PR #15169
- The default parameter to these constructors is
llvm::MaybeAlign()
- When the default parameter is
llvm::MaybeAlign()
the integer attribute pointer isnullptr
.
Just to be complete, I believe in both cases the operation could have the field could be nullptr
or point to I64IntegerAttr(0)
to indicate no alignment requirement. For example, if the user used a different constructor passing all attributes in order.
I think having the documentation indicate that a value of zero indicates no specific alignment requirements is still correct as the Operation's alignment field is still an integer (when present) and it being zero would still signifies no specific alignment requirements.
I think we could also make the alignment attribute required by removing the OptionalAttr and then setting the alignment field point to I64IntegerAttr(0)
to remove the nullptr
and have solely I64IntegerAttr(0)
mean no specific alignment requirement. Happy to add changes if you think it is required :-).
I could also change the line to say that a value of llvm::MaybeAlign()
indicates no specific alignment requirements and values of llvm::Align(n)
for n bigger than zero to be alignment requirements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the detailed explanation! Looks like things have become a bit complex recently 😅
I think having the documentation indicate that a value of zero indicates no specific alignment requirements is still correct as the Operation's alignment field is still an integer (when present) and it being zero would still signifies no specific alignment requirements.
I've just realised that we no longer support { alignment = 0 }
. Take this example:
func.func @load_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
%0 = vector.load %memref[%i, %j] { alignment = 0 } : memref<200x100xf32>, vector<8xf32>
return %0 : vector<8xf32>
}
Now:
$ bin/mlir-opt temp.mlir
temp.mlir:2:36: error: custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0
%0 = vector.load %memref[%i, %j] { alignment = 0 } : memref<200x100xf32>, vector<8xf32>
^
So, zero-alignment is no longer valid :) This makes sense to me - otherwise 0
was some magic value with some magic meaning.
To me, all of this calls for a few updates:
- Docs, i.e. this is no longer correct: "A value of 0 indicates no specific alignment requirement."
- Tests (in invalid.mlir, lets check
alignemt = 0
). - Predicates - can we update this with some alignment-specific predicate? (VectorOfNonZeroRankOf is a nice example of a self-descriptive constraint).
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, you are right. alignment = 0
was not supported in the Attribute due to the constraint.
Regarding next steps:
Docs, i.e. this is no longer correct: "A value of 0 indicates no specific alignment requirement."
This is fine, I'll remove this line.
Tests (in invalid.mlir, lets check alignemt = 0).
I think just updating the ones that are -1
to 0
would be enough.
Predicates - can we update this with some alignment-specific predicate? (VectorOfNonZeroRankOf is a nice example of a self-descriptive constraint).
Sure.
Sounds good! Thanks @banach-space
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@banach-space I am not sure about predicates yet. I think as long as alignment is modeled as an integer attribute, I64Attr
should be written in the tblgen file, and also if it is optional, OptionalAttr
should also be written in the tblgen file.
If the only requirements are positive and power of two, then they are already self-descriptive. One could also have something like:
[MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
// or also OptionalAttr<IntPositivePowerOf2<I64Attr>>:$alignment
def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;
What exactly do you propose as a predicate here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What exactly do you propose as a predicate here?
Basically, something like:
def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;
You are right that IntPositive
and IntPowerOf2
are self descriptive on their own, but IMHO using IntValidAlignment
would create a global definition of what constitutes a valid alignment. This is a nice-to-have.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was having some doubts because IntValidAlignment
would be a little bit of an indirection and now in order to know what a valid alignment is, one would need to look at this definition itself as opposed to having it directly on tablegen. I'll open the new PR with this predicate.
This is a squash of PR llvm#152507
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
@banach-space can we merge or do want to have one more pass? |
An optional `alignment` attribute allows to specify the byte alignment of the | ||
scatter operation. It must be a positive power of 2. The operation must access | ||
memory at an address aligned to this boundary. Violations may lead to | ||
architecture-specific faults or performance penalties. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The violation is strangely defined here, why isn't this specified as UB?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this particular case, there we some operations defined in this other PR #144344 which used this wording as documentation. I think changing the wording to be undefined behaviour is reasonable. Would something like the following be preferrable?
Violations will result in undefined behaviour and may lead to
architecture-specific faults or performance penalties.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
Violating this requirements triggers immediate undefined behavior
seems reasonable for the stores.
For the loads maybe we should instead use:
Violating this requirements will make the loaded value a poison value.
or something like that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think immediate UB is more suitable since using aligned loads on unaligned pointers is known to crash on some architectures
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, but that means we can't speculate an aligned load anymore right?
Maybe that's expected and OK though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can open a PR to change the wording. Thanks @joker-eph !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked LLVM and that seems like what we expect:
It is the responsibility of the code emitter to ensure that the alignment information is correct. Overestimating the alignment results in undefined behavior. Underestimating the alignment may produce less efficient code. An alignment of 1 is always safe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may lead to architecture-specific faults or performance penalties
I personally prefer to avoid phrases like this unless we can provide a specific example. More generic terms like "UB" or "poison" tend to be more universal.
+1 Sorry, I was OOO. |
Following #144344, #152207, #151690, this PR adds the alignment attribute to the following operations in the vector dialect:
compressstore
expandload
vector.scatter
vector.gather