Skip to content

Commit f89c332

Browse files
committed
[MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix
Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an interface method that would allow analyses and cost models to work with it. This enables creation of memrefs of mma_matrix type, which in turn enables seamless fusion in the presence of affine load/stores on such mma memrefs or forwarding of stores to loads out of the box.
1 parent f25185b commit f89c332

File tree

7 files changed

+64
-4
lines changed

7 files changed

+64
-4
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
128128
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
129129
// TODO: consider moving this to ODS.
130130
class MMAMatrixType
131-
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
131+
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
132+
MemRefElementTypeInterface::Trait> {
132133
public:
133134
using Base::Base;
134135

@@ -163,6 +164,9 @@ class MMAMatrixType
163164
/// Get elementType of a single element.
164165
Type getElementType() const;
165166

167+
/// Implementation for MemRefElementTypeInterface.
168+
unsigned getAnalysisSizeInBytes() const;
169+
166170
/// The general form of operation this type supports is given by the equation
167171
/// C += A*B. This function returns which operand in the given equation is
168172
/// held by this type. String returned can be one of "AOp", "BOp" and "COp".

mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td

+4
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
6262
return $_get(memorySpace.getContext(), memorySpace);
6363
}]>
6464
];
65+
let extraClassDeclaration = [{
66+
/// Best effort size for analysis purposes.
67+
unsigned getAnalysisSizeInBytes() { return 8; }
68+
}];
6569
}
6670

6771
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

+13-3
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
7474
For example, scalar values such as integers can implement this interface,
7575
but indicator types such as `void` or `unit` should not.
7676

77-
The interface currently has no methods and is used by types to opt into
78-
being memref elements. This may change in the future, in particular to
79-
require types to provide their size or alignment given a data layout.
77+
The interface currently has one method and is mainly used by types to opt
78+
into being memref elements. This may change in the future, in particular to
79+
require types to provide actual size or alignment given a data layout.
8080
}];
81+
82+
let methods = [
83+
InterfaceMethod<[{
84+
Returns the size of the element type in bytes for purposes such as
85+
analysis. Such a size is meant to be used in analysis cost models as a
86+
best effort in the absence of data layout, as opposed to for
87+
target-specific lowering which would require a data layout.
88+
}],
89+
"unsigned", "getAnalysisSizeInBytes">,
90+
];
8191
}
8292

8393
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
13411341
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
13421342
else
13431343
return std::nullopt;
1344+
} else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
1345+
memRefType.getElementType())) {
1346+
sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
13441347
} else {
13451348
return std::nullopt;
13461349
}

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
149149
elementType.isInteger(32);
150150
}
151151

152+
unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
153+
// The underlying element type is expected to always be int or float and
154+
// typically divisible by 8 bits.
155+
return ShapedType::getNumElements(getShape()) *
156+
llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
157+
}
158+
152159
LogicalResult
153160
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
154161
ArrayRef<int64_t> shape, Type elementType,

mlir/test/Dialect/Affine/loop-fusion-4.mlir

+28
Original file line numberDiff line numberDiff line change
@@ -666,3 +666,31 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
666666
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
667667
return
668668
}
669+
670+
// Test for fusion of affine load/store on memrefs of MMA type.
671+
672+
// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
673+
func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
674+
affine.for %i = 0 to 8 {
675+
affine.for %j = 0 to 4 {
676+
%v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
677+
affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
678+
}
679+
}
680+
681+
affine.for %i = 0 to 8 {
682+
affine.for %j = 0 to 4 {
683+
%v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
684+
affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
685+
}
686+
}
687+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 8 {
688+
// PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 4 {
689+
// PRODUCER-CONSUMER-NEXT: affine.load
690+
// PRODUCER-CONSUMER-NEXT: affine.store
691+
// PRODUCER-CONSUMER-NEXT: affine.load
692+
// PRODUCER-CONSUMER-NEXT: affine.store
693+
694+
return
695+
// PRODUCER-CONSUMER: return
696+
}

mlir/test/lib/Dialect/Test/TestTypeDefs.td

+4
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
169169
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
170170
[MemRefElementTypeInterface]> {
171171
let mnemonic = "memref_element";
172+
173+
let extraClassDeclaration = [{
174+
unsigned getAnalysisSizeInBytes() const { return 1; }
175+
}];
172176
}
173177

174178
def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;

0 commit comments

Comments
 (0)