
Commit c427acb

Update load/store verifier + tests
1 parent 72553d6 commit c427acb

3 files changed (+86, -14 lines)

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 15 additions & 11 deletions
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 // each dimension.
 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  if (descShape == valShape) {
-    if (!sgMap)
-      return true;
-
-    // this can be relaxed if necessary by supporting non-2d shapes distribution
-    // until the constraints are defined this lives here instead of the tensor
-    // descriptor type.
-    return valShape.size() == sgMap.getWiLayout().size();
-  }
+  // Equal shapes with no distribution - no further verification needed.
+  if (descShape == valShape && !sgMap)
+    return true;
 
+  // Unknown distribution - cannot perform operation on partial shape.
   if (!sgMap)
     return false;
 
-  if (valShape.size() != descShape.size())
+  // Invalid rank or mixed rank usage.
+  size_t descRank = descShape.size();
+  if (descRank > 2 || valShape.size() != descRank)
     return false;
 
+  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+  // Only take the distribution over the innermost dimension for validation.
+  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+  if (descRank == 1)
+    mapLayout = {wiLayout.back()};
+
   for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+       llvm::zip_equal(mapLayout, valShape, descShape)) {
     if (factor * dim != expected)
       return false;
   }
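
A note on the new rule: the verifier now accepts 1-D tensor descriptors by validating the distribution only over the innermost wi_layout entry, and for every validated dimension it requires wi_layout factor * per-work-item size == descriptor size. The sketch below mirrors that rule with standard-library types so it can be tried in isolation; shapesDistributeEvenly is a hypothetical name used for illustration, the real check is isArgShapesValid above.

#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal sketch, assuming the same rule as isArgShapesValid: each
// validated dimension must satisfy wiLayout[i] * valShape[i] == descShape[i].
static bool shapesDistributeEvenly(const std::vector<std::int64_t> &descShape,
                                   const std::vector<std::int64_t> &valShape,
                                   const std::vector<std::uint32_t> &wiLayout) {
  std::size_t descRank = descShape.size();
  // Only ranks 1 and 2 are supported, and value/descriptor ranks must match.
  if (descRank == 0 || descRank > 2 || valShape.size() != descRank ||
      wiLayout.size() < descRank)
    return false;
  // For 1-D descriptors, only the innermost layout entry participates.
  std::vector<std::uint32_t> layout(wiLayout);
  if (descRank == 1)
    layout = {wiLayout.back()};
  for (std::size_t i = 0; i < descRank; ++i)
    if (static_cast<std::int64_t>(layout[i]) * valShape[i] != descShape[i])
      return false;
  return true;
}

For example, a 32xf32 descriptor with wi_layout = [1, 16] distributes to vector<2xf32> per work item, since 16 * 2 == 32.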

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 22 additions & 0 deletions
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
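
Both new positive tests exercise the freshly allowed 1-D distributed path: the 32-element descriptors with wi_layout = [1, 16] leave 2 elements per work item (16 * 2 == 32), hence the vector<2xf32> and vector<2xf16> value types. Against the hypothetical sketch from the first file:

shapesDistributeEvenly({32}, {2}, {1, 16});  // true: 16 * 2 == 32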

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 49 additions & 3 deletions
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+    l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8x2xf32>
   return
 }
 
 // -----
 func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+    l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8xf32>
+  return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+    l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
   return
 }
 
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
   return
 }
 
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
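
Each negative test targets one branch of the updated verifier. Against the hypothetical sketch from the first file, the sg_map cases fail the per-dimension product check:

shapesDistributeEvenly({8, 16}, {8, 2}, {1, 16});  // false: 16 * 2 == 32, not 16
shapesDistributeEvenly({16}, {8}, {1, 16});        // false: 16 * 8 == 128, not 16

The remaining cases (vector<8x1xf32> against a plain 8x16 descriptor with no sg_map) fail earlier: unequal shapes with no distribution attached are rejected outright.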
