-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[mlir][vector] Update tests for collapse 3/n (nfc) #94906
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -41,27 +41,27 @@ func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, str | |||||||||||||||||||||||||||||||||||||||
// Same as the top example within this split, but the trailing unit dim was | ||||||||||||||||||||||||||||||||||||||||
// replaced with a dyn dim - not supported | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @non_unit_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ | ||||||||||||||||||||||||||||||||||||||||
func.func @negative_dynamic_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
%cst = arith.constant 0.0 : f32 | ||||||||||||||||||||||||||||||||||||||||
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> | ||||||||||||||||||||||||||||||||||||||||
return %0 : vector<1x8x1xf32> | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @non_unit_trailing_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @negative_dynamic_trailing_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: memref.subview | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: vector.shape_cast | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// Same as the top example within this split, but with a scalable unit dim in | ||||||||||||||||||||||||||||||||||||||||
// the output vector - not supported (scalable 1 is _not_ a unit dimension). | ||||||||||||||||||||||||||||||||||||||||
// Same as the top example within this split, but with a "scalable unit" dim in | ||||||||||||||||||||||||||||||||||||||||
// the output vector - not supported (scalable 1, [1], is _not_ a unit dimension). | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @negative_scalable_unit_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{ | ||||||||||||||||||||||||||||||||||||||||
func.func @negative_scalable_one_trailing_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{ | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
%cst = arith.constant 0.0 : f32 | ||||||||||||||||||||||||||||||||||||||||
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x[1]xf32> | ||||||||||||||||||||||||||||||||||||||||
return %0 : vector<1x8x[1]xf32> | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @negative_scalable_unit_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @negative_scalable_one_trailing_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: memref.subview | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: vector.shape_cast | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
@@ -254,14 +254,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector< | |||||||||||||||||||||||||||||||||||||||
// 2. vector.transfer_write | ||||||||||||||||||||||||||||||||||||||||
//----------------------------------------------------------------------------- | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] | ||||||||||||||||||||||||||||||||||||||||
{in_bounds = [true, true, true, true, true]} | ||||||||||||||||||||||||||||||||||||||||
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @drop_two_inner_most_dim_for_transfer_write | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @drop_two_inner_most_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
|
@@ -272,16 +272,67 @@ func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1 | |||||||||||||||||||||||||||||||||||||||
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// Same as the top example within this split, but with the inner vector | ||||||||||||||||||||||||||||||||||||||||
// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e. | ||||||||||||||||||||||||||||||||||||||||
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @drop_two_inner_most_dim_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] | ||||||||||||||||||||||||||||||||||||||||
{in_bounds = [true, true, true, true, true]} | ||||||||||||||||||||||||||||||||||||||||
: vector<1x16x[16]x1x1xf32>, memref<1x512x16x1x1xf32> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>> | ||||||||||||||||||||||||||||||||||||||||
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x[16]x1x1xf32> to vector<1x16x[16]xf32> | ||||||||||||||||||||||||||||||||||||||||
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// Same as the top example within this split, but the trailing unit dim was | ||||||||||||||||||||||||||||||||||||||||
// replaced with a dyn dim - not supported | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @negative_non_unit_trailing_dim(%arg0: memref<1x512x16x1x?xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] | ||||||||||||||||||||||||||||||||||||||||
{in_bounds = [true, true, true, true, true]} | ||||||||||||||||||||||||||||||||||||||||
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x?xf32> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @negative_non_unit_trailing_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: memref.subview | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: vector.shape_cast | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// Same as the top example within this split, but with a scalable unit dim in | ||||||||||||||||||||||||||||||||||||||||
// the output vector - not supported | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @negative_scalable_unit_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x[1]xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "unit" is consistent with what's used in the pattern definition and the goal here is to remain consistent llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp Lines 1338 to 1354 in 77db8b0
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The pattern is removing unit dims yes, but There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I will update the comments to better document the behaviour. For consistency with the docs and the existing names, I am keeping "scalable unit" (happy to keep in quotes to highlight that it's a special 🌷 ) . As pointed out in my other comment, "scalable unit" vs "scalable one" is a matter of individual preference unless we formalise the definition. |
||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] | ||||||||||||||||||||||||||||||||||||||||
{in_bounds = [true, true, true, true, true]} | ||||||||||||||||||||||||||||||||||||||||
: vector<1x16x16x1x[1]xf32>, memref<1x512x16x1x1xf32> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @negative_scalable_unit_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: memref.subview | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: vector.shape_cast | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0] | ||||||||||||||||||||||||||||||||||||||||
{in_bounds = [true, true, true, true]} | ||||||||||||||||||||||||||||||||||||||||
: vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @drop_inner_most_dim_for_transfer_write | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @drop_inner_most_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
|
@@ -294,14 +345,14 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, | |||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
func.func @outer_dyn_drop_inner_most_dim(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0] | ||||||||||||||||||||||||||||||||||||||||
{in_bounds = [true, true, true, true]} | ||||||||||||||||||||||||||||||||||||||||
: vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @outer_dyn_drop_inner_most_dim | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
|
@@ -325,30 +376,3 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o | |||||||||||||||||||||||||||||||||||||||
// The inner most unit dims can not be dropped if the strides are not ones. | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @non_unit_strides | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: memref.subview | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @leading_scalable_dimension_transfer_write | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] | ||||||||||||||||||||||||||||||||||||||||
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>> | ||||||||||||||||||||||||||||||||||||||||
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32> | ||||||||||||||||||||||||||||||||||||||||
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>> | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
// Negative test: [1] (scalable 1) is _not_ a unit dimension. | ||||||||||||||||||||||||||||||||||||||||
func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) { | ||||||||||||||||||||||||||||||||||||||||
%c0 = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||||
vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32> | ||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||
// CHECK: func.func @trailing_scalable_one_dim_transfer_write | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: vector.shape_cast | ||||||||||||||||||||||||||||||||||||||||
// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32> | ||||||||||||||||||||||||||||||||||||||||
// CHECK-NOT: vector.shape_cast |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no such thing as a scalable unit dim for AArch64
[1]
can contain up to 16 elements!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is target-agnostic - the limitations of any specific target are irrelevant here. If we want to allow/disallow sth, then we need formalise and justify it (i.e. document it). I don't think that's necessary.
As for "scalable unit" vs "scalable one" for
[1]
, I don't follow. The very presence of "scalable" in the name highlights that this is something special. By using "unit" we are consistent with the other tests and the terminology used in the pattern definition. What's wrong with "unit"?If we do advocate for inconsistency, then please formalise the difference between "scalable one" and "scalable unit" through documentation.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unit = an individual thing
But a scalable one dim is not an individual thing, depending on vscale it can be up to 16 elements for AArch64 (but for any scalable target it won't always be one individual element).
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, I prefer "scalable one dim" as "scalable unit dim" feels like an oxymoron.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, and the meaning of "thing" is up for interpretation - yours is different to mine.
While there's no documented definition of "scalable one"/"scalable unit", this is all about our individual preferences ...
Anyway, I've renamed these tests and added comments in "VectorTransforms.cpp".