Skip to content

Commit 3c668aa

Browse files
author
Longsheng Du
authored
[Linalgx] Deprecate all linalgx matmul ops (#307)
1 parent 7223eb0 commit 3c668aa

File tree

11 files changed

+126
-925
lines changed

11 files changed

+126
-925
lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,14 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
109109
if (llvm::isa<linalg::ContractionOpInterface>(linalgOp.getOperation())) {
110110
return getContractionOpOperandDimType(linalgOp);
111111
} else if (linalgx::isGenericPackedMatmulOp(
112-
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D) ||
113-
llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
112+
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D)) {
114113
return SmallVector<SmallVector<DimType>>{
115114
SmallVector<DimType>{DimType::M, DimType::K},
116115
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
117116
DimType::K},
118117
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
119118
} else if (linalgx::isGenericPackedMatmulOp(
120-
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D) ||
121-
llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
119+
linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D)) {
122120
return SmallVector<SmallVector<DimType>>{
123121
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
124122
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,

include/gc/Dialect/Linalgx/LinalgxStructuredOps.td

Lines changed: 0 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -104,212 +104,4 @@ def Linalgx_SigmoidOp : LinalgxStructuredBase_Op<"sigmoid",
104104
}];
105105
}
106106

107-
def Linalgx_Mm2DVnniOp
108-
: LinalgxStructuredBase_Op<"mm2d_vnni", [AttrSizedOperandSegments]> {
109-
let summary = "Transposed matmul with 2d input and vnni packed weights";
110-
let description = [{
111-
Supported format: A[M, K] * B[N0, K0, k, n, v] -> C[M, N], with:
112-
N = N0 * n
113-
K = K0 * k * v; v = (2, 4)
114-
}];
115-
let arguments = (ins
116-
Variadic<TensorOrMemref>:$inputs,
117-
Variadic<TensorOrMemref>:$outputs);
118-
let results = (outs Variadic<TensorOrMemref>:$results);
119-
let regions = (region AnyRegion:$region);
120-
121-
let skipDefaultBuilders = 1;
122-
let builders = [
123-
OpBuilder<
124-
(ins
125-
"TypeRange":$resultTensorTypes,
126-
"ValueRange":$inputs,
127-
"ValueRange":$outputs,
128-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
129-
[{
130-
buildStructuredOp($_builder, $_state, resultTensorTypes,
131-
inputs, outputs, attributes, Mm2DVnniOp::getRegionBuilder());
132-
}]>
133-
];
134-
135-
let hasCustomAssemblyFormat = 1;
136-
let hasFolder = 1;
137-
let hasVerifier = 1;
138-
139-
let extraClassDeclaration = structuredOpsBaseDecls # [{
140-
// Declare functions necessary for LinalgStructuredInterface.
141-
SmallVector<utils::IteratorType> getIteratorTypesArray();
142-
ArrayAttr getIndexingMaps();
143-
static unsigned getNumRegionArgs() { return 3; }
144-
std::string getLibraryCallName() {
145-
return "op_has_no_registered_library_name";
146-
}
147-
148-
// Implement functions necessary for DestinationStyleOpInterface.
149-
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
150-
151-
static void regionBuilder(ImplicitLocOpBuilder &b,
152-
Block &block, ArrayRef<NamedAttribute> attrs);
153-
static std::function<void(ImplicitLocOpBuilder &,
154-
Block &, ArrayRef<NamedAttribute>)>
155-
getRegionBuilder() {
156-
return regionBuilder;
157-
}
158-
}];
159-
}
160-
161-
def Linalgx_Mm4DVnniOp
162-
: LinalgxStructuredBase_Op<"mm4d_vnni", [AttrSizedOperandSegments]> {
163-
let summary = "Transposed matmul with 4d blocking input and vnni packed weights";
164-
let description = [{
165-
Supported format: A[M, K, m, k] * B[N, K, k0, n, v] -> C[M, N, m, n], with:
166-
k = k0 * v; v = (2, 4)
167-
}];
168-
let arguments = (ins
169-
Variadic<TensorOrMemref>:$inputs,
170-
Variadic<TensorOrMemref>:$outputs);
171-
let results = (outs Variadic<TensorOrMemref>:$results);
172-
let regions = (region AnyRegion:$region);
173-
174-
let skipDefaultBuilders = 1;
175-
let builders = [
176-
OpBuilder<
177-
(ins
178-
"TypeRange":$resultTensorTypes,
179-
"ValueRange":$inputs,
180-
"ValueRange":$outputs,
181-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
182-
[{
183-
buildStructuredOp($_builder, $_state, resultTensorTypes,
184-
inputs, outputs, attributes, Mm4DVnniOp::getRegionBuilder());
185-
}]>
186-
];
187-
188-
let hasCustomAssemblyFormat = 1;
189-
let hasFolder = 1;
190-
let hasVerifier = 1;
191-
192-
let extraClassDeclaration = structuredOpsBaseDecls # [{
193-
// Declare functions necessary for LinalgStructuredInterface.
194-
SmallVector<utils::IteratorType> getIteratorTypesArray();
195-
ArrayAttr getIndexingMaps();
196-
static unsigned getNumRegionArgs() { return 3; }
197-
std::string getLibraryCallName() {
198-
return "op_has_no_registered_library_name";
199-
}
200-
201-
// Implement functions necessary for DestinationStyleOpInterface.
202-
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
203-
204-
static void regionBuilder(ImplicitLocOpBuilder &b,
205-
Block &block, ArrayRef<NamedAttribute> attrs);
206-
static std::function<void(ImplicitLocOpBuilder &,
207-
Block &, ArrayRef<NamedAttribute>)>
208-
getRegionBuilder() {
209-
return regionBuilder;
210-
}
211-
}];
212-
}
213-
214-
def Linalgx_BatchReduceMatmulVnniOp
215-
: LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> {
216-
let summary = "Batch reduced matmul with 3d batch input and vnni packed weights";
217-
let description = [{
218-
Supported format: A[B, M, K] * B[B, k, N, v] -> C[M, N], with:
219-
K = k * v; v = (2, 4)
220-
}];
221-
let arguments = (ins
222-
Variadic<TensorOrMemref>:$inputs,
223-
Variadic<TensorOrMemref>:$outputs);
224-
let results = (outs Variadic<TensorOrMemref>:$results);
225-
let regions = (region AnyRegion:$region);
226-
227-
let skipDefaultBuilders = 1;
228-
let builders = [
229-
OpBuilder<
230-
(ins
231-
"TypeRange":$resultTensorTypes,
232-
"ValueRange":$inputs,
233-
"ValueRange":$outputs,
234-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
235-
[{
236-
buildStructuredOp($_builder, $_state, resultTensorTypes,
237-
inputs, outputs, attributes, BatchReduceMatmulVnniOp::getRegionBuilder());
238-
}]>
239-
];
240-
241-
let hasCustomAssemblyFormat = 1;
242-
let hasFolder = 1;
243-
let hasVerifier = 1;
244-
245-
let extraClassDeclaration = structuredOpsBaseDecls # [{
246-
// Declare functions necessary for LinalgStructuredInterface.
247-
SmallVector<utils::IteratorType> getIteratorTypesArray();
248-
ArrayAttr getIndexingMaps();
249-
static unsigned getNumRegionArgs() { return 3; }
250-
std::string getLibraryCallName() {
251-
return "op_has_no_registered_library_name";
252-
}
253-
254-
// Implement functions necessary for DestinationStyleOpInterface.
255-
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
256-
257-
static void regionBuilder(ImplicitLocOpBuilder &b,
258-
Block &block, ArrayRef<NamedAttribute> attrs);
259-
static std::function<void(ImplicitLocOpBuilder &,
260-
Block &, ArrayRef<NamedAttribute>)>
261-
getRegionBuilder() {
262-
return regionBuilder;
263-
}
264-
}];
265-
}
266-
267-
def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
268-
[AttrSizedOperandSegments, LinalgContractionOpInterface]> {
269-
let summary = "Batch matmul with variable batch dims";
270-
let arguments = (ins
271-
Variadic<TensorOrMemref>:$inputs,
272-
Variadic<TensorOrMemref>:$outputs);
273-
let results = (outs Variadic<TensorOrMemref>:$results);
274-
let regions = (region AnyRegion:$region);
275-
276-
let skipDefaultBuilders = 1;
277-
let builders = [
278-
OpBuilder<
279-
(ins
280-
"TypeRange":$resultTensorTypes,
281-
"ValueRange":$inputs,
282-
"ValueRange":$outputs,
283-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
284-
[{
285-
buildStructuredOp($_builder, $_state, resultTensorTypes,
286-
inputs, outputs, attributes, MultiBatchMatmulOp::getRegionBuilder());
287-
}]>
288-
];
289-
290-
let hasCustomAssemblyFormat = 1;
291-
let hasFolder = 1;
292-
293-
let extraClassDeclaration = structuredOpsBaseDecls # [{
294-
// Declare functions necessary for LinalgStructuredInterface.
295-
SmallVector<utils::IteratorType> getIteratorTypesArray();
296-
ArrayAttr getIndexingMaps();
297-
static unsigned getNumRegionArgs() { return 3; }
298-
std::string getLibraryCallName() {
299-
return "op_has_no_registered_library_name";
300-
}
301-
302-
// Implement functions necessary for DestinationStyleOpInterface.
303-
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
304-
305-
static void regionBuilder(ImplicitLocOpBuilder &b,
306-
Block &block, ArrayRef<NamedAttribute> attrs);
307-
static std::function<void(ImplicitLocOpBuilder &,
308-
Block &, ArrayRef<NamedAttribute>)>
309-
getRegionBuilder() {
310-
return regionBuilder;
311-
}
312-
}];
313-
}
314-
315107
#endif // LINALGX_STRUCTURED_OPS

include/gc/Dialect/Linalgx/Utils.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace mlir {
1818
namespace linalgx {
1919

2020
/// @brief enum of type of matmul packing
21-
enum class PackingType {
21+
enum class PackingType : int {
2222
MM4D = 0, // MKmk x NKkn
2323
VNNI_MM2D, // MK x NKknV
2424
VNNI_MM4D, // MKmk x NKknV
@@ -43,6 +43,30 @@ makeGenericPackedMatmulOp(OpBuilder &builder, Location loc, PackingType opType,
4343
/// @return true if op is a generic packed matmul Op
4444
bool isGenericPackedMatmulOp(Operation *op, PackingType opType);
4545

46+
template <typename... Args>
47+
inline bool isGenericPackedMatmulOp(Operation *op, PackingType first,
48+
Args... args) {
49+
return isGenericPackedMatmulOp(op, first) ||
50+
isGenericPackedMatmulOp(op, args...);
51+
}
52+
53+
/// @brief identify a generic packed matmul Op based on any PackingType
54+
/// @param op the op
55+
/// @return true if op is a generic packed matmul Op
56+
template <int T, int N> inline bool isAnyGenericPackedMatmulOp(Operation *op) {
57+
return isGenericPackedMatmulOp(op, (PackingType)N) ||
58+
isAnyGenericPackedMatmulOp<T + 1, N>(op);
59+
}
60+
constexpr int NUM_ALL_TYPES = (int)PackingType::NUM_TYPES;
61+
template <>
62+
inline bool
63+
isAnyGenericPackedMatmulOp<NUM_ALL_TYPES, NUM_ALL_TYPES>(Operation *op) {
64+
return false;
65+
}
66+
inline bool isAnyGenericPackedMatmulOp(Operation *op) {
67+
return isAnyGenericPackedMatmulOp<0, NUM_ALL_TYPES>(op);
68+
}
69+
4670
/// @brief identify a matmul Op based on ContractionOp and PackingType
4771
/// @param op the op
4872
/// @return true if op is a matmul Op

0 commit comments

Comments
 (0)