Skip to content

Commit 3f3282c

Browse files
authored
[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)
* 1-to-1 mapping wrapper op. * Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
1 parent e2cc82b commit 3f3282c

File tree

7 files changed

+281
-2
lines changed

7 files changed

+281
-2
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp :
898898
let hasVerifier = 1;
899899
}
900900

901+
// Wrapper op for the `ds_read_tr` family of transpose loads. The verifier
// requires a workgroup (LDS) source memref and one of the supported result
// vector shapes; the AMDGPU-to-ROCDL lowering maps it 1-to-1 to an intrinsic.
def AMDGPU_TransposeLoadOp :
    AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
    Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
    Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
  let summary = "MLIR wrapper for CDNA Transpose Load instructions";
  let description = [{
    The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
    The transpose load op represents a subgroup load from LDS memory,
    where the subgroup of threads collectively reads a matrix from the source
    memref, with each thread reading a vector of the matrix, and obtains the
    transposed matrix as the result. That is, each thread reads a vector of the
    col-major matrix at different indices, and the thread's read result is a
    vector of the corresponding row of the transposed matrix.

    This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer
    to the CDNA4 ISA documentation for more details about its exact semantics.

    The source memref must live in the workgroup (LDS) address space, and the
    result vector must have one of the following shapes:
    * 16 elements of a 4-bit type,
    * 16 elements of a 6-bit type,
    * 8 elements of an 8-bit type,
    * 4 elements of a 16-bit type.

    Format example:
    ```
    %0 = amdgpu.transpose_load %src[%i, %j] : memref<128x256xf16, 3> -> vector<4xf16>
    ```
    Operands:
    * `$src`: LDS memref to read from.
    * `$srcIndices`: indices into `$src` to read from for this thread.

    Results:
    * `$result`: target register this transpose load instruction will write to.

    Note: Lowering to ROCDL is currently only supported on gfx950. The lowering
    additionally rejects source memrefs whose element type is narrower than
    8 bits; use a byte-sized (emulated) memref for sub-byte result vectors.
  }];
  let assemblyFormat = [{
    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
  }];
  let hasVerifier = 1;
}
934+
901935
def AMDGPU_ScaledMFMAOp :
902936
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
903937
Pure]>,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
11001100
}
11011101
};
11021102

1103+
struct TransposeLoadOpLowering
1104+
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
1105+
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1106+
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1107+
1108+
Chipset chipset;
1109+
1110+
LogicalResult
1111+
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1112+
ConversionPatternRewriter &rewriter) const override {
1113+
if (chipset != kGfx950)
1114+
return op.emitOpError("Non-gfx950 chipset not supported");
1115+
1116+
Location loc = op.getLoc();
1117+
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1118+
1119+
// Elements in subbyte memrefs are stored non-contiguously,
1120+
// reject if source is sub-byte memref. Use emulated memrefs instead.
1121+
size_t srcElementSize =
1122+
srcMemRefType.getElementType().getIntOrFloatBitWidth();
1123+
if (srcElementSize < 8)
1124+
return op.emitOpError("Expect source memref to have at least 8 bits "
1125+
"element size, got ")
1126+
<< srcElementSize;
1127+
1128+
auto resultType = cast<VectorType>(op.getResult().getType());
1129+
Value srcPtr =
1130+
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1131+
(adaptor.getSrcIndices()));
1132+
1133+
size_t numElements = resultType.getNumElements();
1134+
size_t elementTypeSize =
1135+
resultType.getElementType().getIntOrFloatBitWidth();
1136+
1137+
// ROCDL transpose load intrinsics return vectors of 32-bit integers, if
1138+
// the element size is smaller than 16 bits.
1139+
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1140+
rewriter.getIntegerType(32));
1141+
Type llvmResultType = typeConverter->convertType(resultType);
1142+
1143+
switch (elementTypeSize) {
1144+
case 4: {
1145+
assert(numElements == 16);
1146+
auto rocdlOp =
1147+
rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
1148+
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1149+
break;
1150+
}
1151+
case 6: {
1152+
assert(numElements == 16);
1153+
auto rocdlOp =
1154+
rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
1155+
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1156+
break;
1157+
}
1158+
case 8: {
1159+
assert(numElements == 8);
1160+
auto rocdlOp =
1161+
rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
1162+
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1163+
break;
1164+
}
1165+
case 16: {
1166+
assert(numElements == 4);
1167+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1168+
srcPtr);
1169+
break;
1170+
}
1171+
default:
1172+
return op.emitOpError("Unsupported element size for transpose load");
1173+
}
1174+
return success();
1175+
}
1176+
};
1177+
11031178
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11041179
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
11051180
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
17491824
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
17501825
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
17511826
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752-
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1753-
chipset);
1827+
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1828+
TransposeLoadOpLowering>(converter, chipset);
17541829
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
17551830
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/IR/OpImplementation.h"
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/IR/TypeUtilities.h"
27+
#include "llvm/ADT/DenseMap.h"
2728
#include "llvm/ADT/TypeSwitch.h"
2829

2930
#include <limits>
@@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() {
524525
return success();
525526
}
526527

528+
/// Verifies that the source memref is in the workgroup address space and that
/// the result vector has one of the (element bit width, element count)
/// combinations listed in the table below.
// NOTE(review): the number of `srcIndices` is not checked against the rank of
// the source memref here — confirm whether that should be enforced.
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  Type elementType = transferType.getElementType();
  // getIntOrFloatBitWidth() asserts on element types that are neither integer
  // nor float (e.g. vector<4xindex>), so diagnose those instead of crashing.
  if (!elementType.isIntOrFloat())
    return emitOpError("Unsupported element type for transpose load: ")
           << elementType;

  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize = elementType.getIntOrFloatBitWidth();

  // ElementSize -> NumElements
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }

  return success();
}
560+
527561
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
528562

529563
#define GET_ATTRDEF_CLASSES
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
2+
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
3+
4+
// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
5+
func.func @transpose_load_to_rocdl_4xf16(%row : index, %col : index, %lds : memref<128x72xf16, 3>) -> vector<4xf16> {
  // 16-bit elements: expect the tr16 intrinsic on gfx950 and a chipset error
  // on the older target.
  // CHECK: rocdl.ds.read.tr16.b64
  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x72xf16, 3> -> vector<4xf16>
  return %loaded : vector<4xf16>
}
11+
12+
// -----
13+
14+
// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
15+
func.func @transpose_load_to_rocdl_8xi8(%row : index, %col : index, %lds : memref<128x128xi8, 3>) -> vector<8xi8> {
  // 8-bit elements: the intrinsic yields vector<2xi32>, which is bitcast back
  // to the requested vector type.
  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
  // CHECK-SAME: -> vector<2xi32>
  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x128xi8, 3> -> vector<8xi8>
  return %loaded : vector<8xi8>
}
23+
24+
// -----
25+
26+
// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
27+
func.func @transpose_load_to_rocdl_i4_memrefxi8(%row : index, %col : index, %lds : memref<128x32xi8, 3>) -> vector<16xi4> {
  // 4-bit result elements read through a byte-typed memref: expect the tr4
  // intrinsic followed by a bitcast from vector<2xi32>.
  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
  // CHECK-SAME: -> vector<2xi32>
  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xi8, 3> -> vector<16xi4>
  return %loaded : vector<16xi4>
}
35+
36+
// -----
37+
38+
// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
39+
func.func @transpose_load_to_rocdl_i6_memrefxi8(%row : index, %col : index, %lds : memref<128x32xi8, 3>) -> vector<16xi6> {
  // 6-bit result elements read through a byte-typed memref: expect the tr6
  // intrinsic followed by a bitcast from vector<3xi32>.
  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
  // CHECK-SAME: -> vector<3xi32>
  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xi8, 3> -> vector<16xi6>
  return %loaded : vector<16xi6>
}
47+
48+
// -----
49+
50+
// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
51+
func.func @transpose_load_to_rocdl_i16_memrefxi8(%row : index, %col : index, %lds : memref<128x32xi8, 3>) -> vector<4xi16> {
  // 16-bit result elements read through a byte-typed memref: expect the tr16
  // intrinsic.
  // CHECK: rocdl.ds.read.tr16.b64
  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xi8, 3> -> vector<4xi16>
  return %loaded : vector<4xi16>
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s
2+
3+
// -----
4+
5+
func.func @transpose_load_to_rocdl_16xi4(%row : index, %col : index, %lds : memref<128x16xi4, 3>) -> vector<16xi4> {
  // A source memref with 4-bit elements must be rejected by the lowering.
  // CHECK: memref to have at least 8 bits element size, got 4
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x16xi4, 3> -> vector<16xi4>
  return %loaded : vector<16xi4>
}
10+
11+
// -----
12+
13+
func.func @transpose_load_to_rocdl_16xi6(%row : index, %col : index, %lds : memref<128x32xi6, 3>) -> vector<16xi6> {
  // A source memref with 6-bit elements must be rejected by the lowering.
  // CHECK: memref to have at least 8 bits element size, got 6
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xi6, 3> -> vector<16xi6>
  return %loaded : vector<16xi6>
}

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,59 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
166166
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
167167
func.return %0 : vector<[4]xf32>
168168
}
169+
170+
// -----
171+
172+
func.func @transpose_load_addrspace(%row : index, %col : index, %buf : memref<128x32xf16, 1>) -> vector<4xf16> {
  // The source memref is in address space 1, which the verifier rejects.
  // expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
  %loaded = amdgpu.transpose_load %buf[%row, %col] : memref<128x32xf16, 1> -> vector<4xf16>
  func.return %loaded : vector<4xf16>
}
177+
178+
// -----
179+
180+
// NOTE(review): this case repeats the preceding address-space test verbatim;
// consider dropping it or switching it to a different non-workgroup space.
func.func @transpose_load_addrspace(%row : index, %col : index, %buf : memref<128x32xf16, 1>) -> vector<4xf16> {
  // expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
  %loaded = amdgpu.transpose_load %buf[%row, %col] : memref<128x32xf16, 1> -> vector<4xf16>
  func.return %loaded : vector<4xf16>
}
185+
186+
// -----
187+
188+
func.func @transpose_load_elem_f32(%row : index, %col : index, %lds : memref<128x32xf32, 3>) -> vector<4xf32> {
  // 32-bit result elements are not among the supported element sizes.
  // expected-error@+1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}}
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xf32, 3> -> vector<4xf32>
  func.return %loaded : vector<4xf32>
}
193+
194+
// -----
195+
196+
func.func @transpose_load_vector_size_f16(%row : index, %col : index, %lds : memref<128x32xf16, 3>) -> vector<2xf16> {
  // 16-bit elements require a 4-element result vector; 2 elements is invalid.
  // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}}
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xf16, 3> -> vector<2xf16>
  func.return %loaded : vector<2xf16>
}
201+
202+
// -----
203+
204+
func.func @transpose_load_vector_size_i4(%row : index, %col : index, %lds : memref<128x32xi4, 3>) -> vector<20xi4> {
  // 4-bit elements require a 16-element result vector; 20 elements is invalid.
  // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xi4, 3> -> vector<20xi4>
  func.return %loaded : vector<20xi4>
}
209+
210+
// -----
211+
212+
func.func @transpose_load_vector_size_i8(%row : index, %col : index, %lds : memref<128x32xi8, 3>) -> vector<20xi8> {
  // 8-bit elements require an 8-element result vector; 20 elements is invalid.
  // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}}
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xi8, 3> -> vector<20xi8>
  func.return %loaded : vector<20xi8>
}
217+
218+
// -----
219+
220+
// Named after the element type it actually exercises (i6); the previous name
// duplicated the i8 case.
func.func @transpose_load_vector_size_i6(%idx1 : index, %idx2 : index, %mem : memref<128x32xi6, 3>) -> vector<8xi6> {
  // 6-bit elements require a 16-element result vector; 8 elements is invalid.
  // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
  func.return %0 : vector<8xi6>
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v
486486
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
487487
func.return %0 : vector<16xf32>
488488
}
489+
490+
// CHECK-LABEL: func @transpose_load
491+
func.func @transpose_load(%row : index, %col : index, %lds : memref<128x32xf16, 3>) -> vector<4xf16> {
  // Round-trip check: the op must parse and print back.
  // CHECK: amdgpu.transpose_load
  %loaded = amdgpu.transpose_load %lds[%row, %col] : memref<128x32xf16, 3> -> vector<4xf16>
  func.return %loaded : vector<4xf16>
}

0 commit comments

Comments
 (0)