Skip to content

Commit 55c896e

Browse files
committed
Transposed 2d load.
Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
1 parent c464166 commit 55c896e

File tree

2 files changed

+51
-19
lines changed

2 files changed

+51
-19
lines changed

python/test/unit/intel/test_block_store.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,6 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, transpose, device, tmp_
204204

205205
a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
206206

207-
print("a:", a.shape, a.stride())
208-
print("x:", x.shape, x.stride())
209-
210207
kernel[(1, 1, 1)](a, x)
211208
assert torch.equal(a, x)
212209

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,11 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
373373

374374
// Returns the pitch (stride in bytes) of \p ptr.
375375
Value getPitch(ConversionPatternRewriter &rewriter, Value ptr,
376-
unsigned elemSizeInBits) const {
376+
unsigned elemSizeInBits, unsigned dim = 0) const {
377377
Location loc = ptr.getLoc();
378378
auto b = TritonLLVMOpBuilder(loc, rewriter);
379379

380-
int stride = getStride(ptr, 0);
380+
int stride = getStride(ptr, dim);
381381
// If the stride is 0, we assume a minimum pitch of 64 bytes.
382382
constexpr int MIN_PITCH = 64;
383383
if (stride == 0)
@@ -1884,17 +1884,6 @@ struct LoadOpToBlockIOConversion
18841884
// HW issue for vblock = 4
18851885
vBlocks = vBlocks == 4 ? 1 : vBlocks;
18861886

1887-
// TODO: use the axis info to general the handling for both regular pointer
1888-
// and block pointer.
1889-
const bool memoryRowMajor = isMemoryRowMajor(op);
1890-
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
1891-
const bool isTransposeRequired = contiguousDim != colDim;
1892-
1893-
if (isTransposeRequired) {
1894-
// TODO: support load column major data.
1895-
return failure();
1896-
}
1897-
18981887
Location loc = op.getLoc();
18991888
MLIRContext *ctx = op.getContext();
19001889
auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -2012,13 +2001,59 @@ struct LoadOpToBlockIOConversion
20122001
otherElems = unpackLLElements(loc, llOther, rewriter);
20132002
}
20142003

2004+
// TODO: use the axis info to general the handling for both regular pointer
2005+
// and block pointer.
2006+
const bool memoryRowMajor = isMemoryRowMajor(op);
2007+
unsigned contiguousDim = memoryRowMajor ? 1 : 0;
2008+
const bool isTransposeRequired = contiguousDim != colDim;
2009+
2010+
if (isTransposeRequired) {
2011+
if (numPackedVals > 1)
2012+
return failure();
2013+
if (elemSizeInBits > 32)
2014+
return failure();
2015+
if (tileWidth > 32)
2016+
return failure(); // tileWidth is limited to 32 for transpose 2d load.
2017+
2018+
vBlocks = 1;
2019+
2020+
// use the d32 for transpose 2d load.
2021+
packedElemSizeInBits = 32;
2022+
numPackedVals = packedElemSizeInBits / elemSizeInBits;
2023+
tileHeight = std::min(tileHeight / numPackedVals, 8);
2024+
2025+
// transpose the width and height of the tile
2026+
std::swap(tileHeight, tileWidth);
2027+
if (tileHeight * tileWidth < threadsPerWarp)
2028+
return failure(); // The tile size is not large enough for IGC scalar
2029+
// backend vectorization.
2030+
// if (oneMatrixPerLoadForBT) {
2031+
// // Only load 1 operand per inst on row.
2032+
// numOperandsPer2DLoadM = 1;
2033+
// tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2034+
// } else {
2035+
// // We can decompose the matrix returned by transposed large 2d load
2036+
// // when threads per warp < column size. Otherwise we have to load one
2037+
// // operand per inst.
2038+
// // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2039+
// // now.
2040+
// numOperandsPer2DLoadM =
2041+
// (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2042+
// }
2043+
// // The transpose 2d load only support 1 operand per inst on column.
2044+
// // (vBlocks = 1)
2045+
// numOperandsPer2DloadN = 1;
2046+
// // TODO: support load column major data.
2047+
// return failure();
2048+
}
2049+
20152050
baseWidth = b.i32_val(
20162051
std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
20172052
// If the stride is 0, we want to load only the first row.
2018-
int stride = getStride(ptr, 0);
2053+
int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
20192054
baseHeightInt = (stride == 0 ? 1 : tileHeight);
20202055
baseHeight = b.i32_val(baseHeightInt);
2021-
pitch = getPitch(rewriter, ptr, elemSizeInBits);
2056+
pitch = getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
20222057
if (!pitch)
20232058
return failure();
20242059

@@ -2161,7 +2196,7 @@ struct LoadOpToBlockIOConversion
21612196
/*tile_width*/ tileWidth,
21622197
/*tile_height*/ tileHeight,
21632198
/*v_blocks*/ vBlocks,
2164-
/*transpose*/ false,
2199+
/*transpose*/ isTransposeRequired,
21652200
/*vnni_transform*/ useVNNIFormat);
21662201

21672202
// When strides[0] is 0, we only want to load the first row, so we

0 commit comments

Comments
 (0)