
[mlir][xegpu] Add SIMT distribution patterns for UpdateNdOffset, PrefetchNd and GPU Index Ops. #136743


Draft · wants to merge 78 commits into base: main
Commits (78) · changes shown from 3 commits
39dcf9d
save work
charithaintc Mar 18, 2025
2058773
moving all ops to region working
charithaintc Mar 20, 2025
14233fa
moving all ops to region working
charithaintc Mar 20, 2025
f599873
save work
charithaintc Mar 20, 2025
220ed1f
save work
charithaintc Mar 21, 2025
2a8070f
save work
charithaintc Mar 21, 2025
4838b52
extend sg_map from subgroup to workgroup
chencha3 Mar 21, 2025
cb26979
format code
chencha3 Mar 21, 2025
273fc40
remove changes to prefetch op
chencha3 Mar 21, 2025
504d274
refine the doc for TensorDesc
chencha3 Mar 21, 2025
90e0704
save work
charithaintc Mar 21, 2025
3abe7cb
save work
charithaintc Mar 21, 2025
7c87319
Merge branch 'main' into xegpu_simt_dist
charithaintc Mar 21, 2025
596c953
update doc
chencha3 Mar 21, 2025
2065764
save work
charithaintc Mar 21, 2025
899439b
refine docs
chencha3 Mar 24, 2025
8636d15
refine docs
chencha3 Mar 24, 2025
0190418
refine util
chencha3 Mar 24, 2025
32f9272
refine convert_layout docs
chencha3 Mar 24, 2025
fe11c79
save work
charithaintc Mar 24, 2025
6e1ef3e
save work
charithaintc Mar 24, 2025
55c272c
save work
charithaintc Mar 25, 2025
ee56a3e
Merge branch 'gpu_dialect_changes' into xegpu_simt_dist
charithaintc Mar 25, 2025
1ffe5c8
save work
charithaintc Mar 26, 2025
e5521f9
save work before merging with Chao's PR
charithaintc Mar 27, 2025
350b581
Merge branch 'users/chencha3/xegpu/extend_sg_map' into xegpu_simt_dist
charithaintc Mar 27, 2025
5700c81
merge xegpu changes
charithaintc Mar 29, 2025
1619fcf
Merge branch 'main' into xegpu_simt_dist
charithaintc Mar 31, 2025
2334a97
refactor names
charithaintc Mar 31, 2025
9bddeb6
drop ScopeAttr and refine 1D layout support
chencha3 Apr 1, 2025
784ab38
refine isEvenDistributed
chencha3 Apr 1, 2025
28cf69e
format code
chencha3 Apr 1, 2025
930f1ab
Merge branch 'main' into extend_sg_map
chencha3 Apr 1, 2025
9ed0f87
fix format issue
chencha3 Apr 1, 2025
3b389bf
add 1D layout examples
chencha3 Apr 1, 2025
589d217
refactor names
charithaintc Apr 2, 2025
8b647c4
Merge branch 'users/chencha3/xegpu/extend_sg_map' into xegpu_simt_dist
charithaintc Apr 2, 2025
c6ccef2
refactor
charithaintc Apr 2, 2025
cbd0af0
refine LayoutAttr verifier
chencha3 Apr 4, 2025
3fb4fd4
add unit test
chencha3 Apr 4, 2025
77fdfef
remove dump file
chencha3 Apr 4, 2025
2751332
fix typo
chencha3 Apr 4, 2025
2a16d11
Merge branch 'main' into extend_sg_map
chencha3 Apr 4, 2025
d281a14
fix an error after mering with main
chencha3 Apr 4, 2025
fb28ce8
new line at the end of file
chencha3 Apr 7, 2025
f464662
update doc
chencha3 Apr 8, 2025
eea3c35
Merge branch 'main' into extend_sg_map
chencha3 Apr 8, 2025
7acc56d
Merge branch 'users/chencha3/xegpu/extend_sg_map' into xegpu_simt_dist
charithaintc Apr 8, 2025
270b498
Merge branch 'main' into xegpu_simt_dist
charithaintc Apr 9, 2025
2a1d373
Switch to 1D representation for SIMT
chencha3 Apr 10, 2025
2159119
refine verfier for load_nd and store_nd
chencha3 Apr 10, 2025
21f50c0
fix issues
charithaintc Apr 10, 2025
35f9cbe
Merge branch 'main' into xegpu_simt_dist
charithaintc Apr 10, 2025
c81b2e0
fix issues
charithaintc Apr 10, 2025
03bfe08
Merge branch 'users/chencha3/xegpu/xegpu_simt_2d_to_1d' into xegpu_si…
charithaintc Apr 11, 2025
2f2ec10
fix issues
charithaintc Apr 14, 2025
2ae3543
fix issues
charithaintc Apr 14, 2025
4c63916
fix issues
charithaintc Apr 14, 2025
2d9cfa3
fix build issue
charithaintc Apr 15, 2025
775d039
refine verifier for gather/scatter
chencha3 Apr 15, 2025
5520ce1
update comments
chencha3 Apr 15, 2025
6abc12a
fix tests
charithaintc Apr 15, 2025
379e186
fix
charithaintc Apr 16, 2025
aa7dbe1
fix
charithaintc Apr 16, 2025
dce6d2a
Merge branch 'users/chencha3/xegpu/xegpu_simt_2d_to_1d' into xegpu_si…
charithaintc Apr 16, 2025
ca5c7e9
fix comments
charithaintc Apr 16, 2025
ed3119c
fix comments
charithaintc Apr 16, 2025
c898de6
fix comments
charithaintc Apr 17, 2025
55be710
fix comments
charithaintc Apr 17, 2025
6e8888a
fix
charithaintc Apr 18, 2025
6ae7aa0
fix
charithaintc Apr 18, 2025
2896b34
Merge branch 'main' into xegpu_simt_dist
charithaintc Apr 18, 2025
68b1750
fix
charithaintc Apr 18, 2025
5f1798d
save work
charithaintc Apr 21, 2025
9391696
Merge branch 'main' into xegpu_simt_dist
charithaintc Apr 22, 2025
b3e6dc5
save work
charithaintc Apr 22, 2025
08d9e7b
Merge branch 'xegpu_simt_dist' into distribute_scf
charithaintc Apr 22, 2025
6447c63
add prefetch support
charithaintc Apr 22, 2025
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -322,21 +322,22 @@ LogicalResult TensorDescType::verify(
 // ---------------------------------------------------------------------
 // Case 1: Regular loads/stores.
 // ---------------------------------------------------------------------
-// Distributed vector shape must be:
-//   [chunk_size / lane_data_size, lane_data_size]
-// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
-//   [lane_data_size]
+// The following conditions must be met:
+//   * tensor_desc[0] == lane_layout[0]
+// Distributed vector is a 1D vector with shape:
+//   [chunk_size]
 // ---------------------------------------------------------------------
 // Case 2: Block loads/stores
 // ---------------------------------------------------------------------
 // Additional definitions:
 //   tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
 //   n_distribution_units = tensor_size / distribution_unit_size
+//   fragment_size = n_distribution_units * lane_data_size
 // Given above definitions, the following conditions must be met:
 //   * tensor_desc[0] % (lane_layout[0] × lane_data[0]) == 0
 //   * tensor_desc[1] % (lane_layout[1] × lane_data[1]) == 0
-// Distributed vector shape must be:
-//   [n_distribution_units, lane_data_size]
+// Distributed vector is a 1D vector with shape:
+//   [fragment_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
   // It only works for subgroup level layout, which only has lane_layout
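The Case 2 arithmetic in the comment above can be sketched as a standalone helper. This is an illustrative sketch, not the MLIR implementation: the function name `blockFragmentSize` is invented, and it assumes `distribution_unit_size = subgroup_size * lane_data_size` (with `subgroup_size` the product of `lane_layout` and `lane_data_size` the product of `lane_data`), a definition elided from the excerpt.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Sketch of the Case 2 (block load/store) rule: check even distribution
// per dimension, then compute the per-lane 1D fragment size.
std::optional<int64_t>
blockFragmentSize(const std::vector<int64_t> &tdescShape,
                  const std::vector<int64_t> &laneLayout,
                  const std::vector<int64_t> &laneData, int64_t arrayLength) {
  auto product = [](const std::vector<int64_t> &v) {
    return std::accumulate(v.begin(), v.end(), int64_t{1},
                           std::multiplies<int64_t>());
  };
  // * tensor_desc[i] % (lane_layout[i] * lane_data[i]) == 0
  for (size_t i = 0; i < tdescShape.size(); ++i)
    if (tdescShape[i] % (laneLayout[i] * laneData[i]) != 0)
      return std::nullopt;
  int64_t tensorSize = product(tdescShape) * arrayLength; // tensor_size
  int64_t laneDataSize = product(laneData);               // lane_data_size
  // Assumed: distribution_unit_size = subgroup_size * lane_data_size.
  int64_t distributionUnitSize = product(laneLayout) * laneDataSize;
  int64_t nDistributionUnits = tensorSize / distributionUnitSize;
  return nDistributionUnits * laneDataSize; // fragment_size
}
```

For example, an 8x16 descriptor with lane_layout [1, 16] and lane_data [1, 1] yields a per-lane vector of 8 elements under these assumptions.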
63 changes: 20 additions & 43 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -295,7 +295,6 @@ LogicalResult LoadNdOp::verify() {
   }
 
   // Check SIMD mode.
-  // adjusted tensor descriptor shape tracks the expected shape of the result.
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
@@ -547,38 +546,27 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
   auto chunkSize = tdescTy.getChunkSize();
-  // for SIMT code, the value should be 1D vector with size of chunkSize.
-  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
-    if (valueTy.getNumElements() != chunkSize) {
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
+    if (tdescTy.getLayoutAttr())
       return emitOpError()
-             << "Result shape " << makeString(valueShape)
-             << " is not a valid distribution for tensor descriptor "
-             << tdescTy;
-    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
-      if (tdescTy.getLayoutAttr())
-        return emitOpError()
-               << "TensorDesc doesn't need LayoutAttr for SIMT code";
-      if (getTransposeAttr())
-        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
-    }
-    return success();
-  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
-    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
-    // it is a valid SIMT code if chunkSize happens to be the same as
-    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+    if (getTransposeAttr())
+      return emitOpError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
   // For SIMD code verification.
-  if (tdescTy.getRank() == 2) {
+  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("load of rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
   if (tdescShape != valueShape)
     return emitOpError() << "Result shape " << makeString(valueShape)
-                         << " is not consistent with tensor descriptor "
+                         << " is neither a valid distribution for SIMT nor "
+                            "consistent with the tensor descriptor for SIMD "
                          << tdescTy;
   return success();
 }
@@ -613,38 +601,27 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
   auto chunkSize = tdescTy.getChunkSize();
-  // for SIMT code, the value should be 1D vector with size of chunkSize.
-  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
-    if (valueTy.getNumElements() != chunkSize) {
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
+    if (tdescTy.getLayoutAttr())
       return emitOpError()
-             << "Value shape " << makeString(valueShape)
-             << " is not a valid distribution for tensor descriptor "
-             << tdescTy;
-    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
-      if (tdescTy.getLayoutAttr())
-        return emitOpError()
-               << "TensorDesc doesn't need LayoutAttr for SIMT code";
-      if (getTransposeAttr())
-        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
-    }
-    return success();
-  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
-    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
-    // it is a valid SIMT code if chunkSize happens to be the same as
-    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+    if (getTransposeAttr())
+      return emitOpError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
   // for SIMD code verification.
-  if (tdescTy.getRank() == 2) {
+  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("Store of a rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
   if (tdescShape != valueShape)
     return emitOpError() << "Value shape " << makeString(valueShape)
-                         << " is not consistent with tensor descriptor "
+                         << " is neither a valid distribution for SIMT nor "
+                            "consistent with the tensor descriptor for SIMD "
                          << tdescTy;
 
   return success();
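The rewritten gather/scatter verifiers above share one decision: accept a 1D value of exactly chunk_size elements as SIMT, otherwise fall back to the SIMD shape check and report a combined error. A simplified stand-in for that control flow follows; the names `verifyGatherScatterValue` and `VerifyResult` are invented for illustration, and the real verifiers additionally reject LayoutAttr/TransposeAttr in the SIMT case, which is omitted here.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

enum class VerifyResult { SimtOk, SimdOk, Error };

// Simplified stand-in for the updated LoadGatherOp/StoreScatterOp checks.
VerifyResult verifyGatherScatterValue(std::vector<int64_t> tdescShape,
                                      std::vector<int64_t> valueShape,
                                      int64_t chunkSize, bool hasTranspose) {
  int64_t numElements = 1;
  for (int64_t d : valueShape)
    numElements *= d;
  // SIMT: a 1D value of exactly chunk_size elements is valid as-is.
  if (valueShape.size() == 1 && numElements == chunkSize)
    return VerifyResult::SimtOk;
  // SIMD: a rank-2 access must be transposed before the shape comparison.
  if (tdescShape.size() == 2 && valueShape.size() == 2) {
    if (!hasTranspose)
      return VerifyResult::Error; // rank-2 tensor has to be transposed
    std::swap(tdescShape[0], tdescShape[1]);
  }
  // Otherwise the value is neither a valid SIMT distribution nor
  // consistent with the tensor descriptor for SIMD.
  return tdescShape == valueShape ? VerifyResult::SimdOk : VerifyResult::Error;
}
```

With the descriptor from the negative tests in invalid.mlir (4x2, chunk_size = 2), a vector<6xf32> value reaches the combined error path, while a vector<2xf32> value is accepted as SIMT.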
4 changes: 2 additions & 2 deletions mlir/test/Dialect/XeGPU/invalid.mlir
@@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{Result shape [6] is not a valid distribution for tensor descriptor}}
+  // expected-error@+1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<6xf32>
   return
 }
@@ -266,7 +266,7 @@ func.func @test_store_scatter_simt_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %val = arith.constant dense<2.9>: vector<6xf32>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{Value shape [6] is not a valid distribution for tensor descriptor}}
+  // expected-error@+1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
   xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : vector<6xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return
 }
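For contrast with the two negative tests above, here is a sketch of a shape the updated verifier should accept in SIMT mode. This is hypothetical, adapted from the load test by resizing the result to chunk_size; it is not taken from the PR's test suite.

```mlir
// Per-lane result is a 1D vector of chunk_size (= 2) elements; the
// descriptor carries no LayoutAttr and the op no TransposeAttr.
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}>
    : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>,
      vector<4xi1> -> vector<2xf32>
```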