Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Prevent non 2d shaped loads/stores to have an sg_map
  • Loading branch information
kurapov-peter committed Jan 22, 2025
commit 9213451ce4f9b36aa4794c2ca7adb9238041482c
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,15 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
// each dimension.
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
if (descShape == valShape)
return true;
if (descShape == valShape) {
if (!sgMap)
return true;

// this can be relaxed if necessary by supporting non-2d shapes distribution
// until the constraints are defined this lives here instead of the tensor
// descriptor type.
return valShape.size() == sgMap.getWiLayout().size();
}

if (!sgMap)
return false;
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/XeGPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
return
}

// -----
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
%2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
return
}

// -----
func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
%1 = arith.constant dense<1.0>: vector<24x32xf16>
Expand Down
Loading