Skip to content

Commit 73696ee

Browse files
committed
address Max191's comments
1 parent e972232 commit 73696ee

File tree

5 files changed

+664
-75
lines changed

5 files changed

+664
-75
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
155155
}
156156

157157
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158-
[AllElementTypesMatch<["filter", "output"]>]> {
158+
[AllElementTypesMatch<["filter", "output"]>,
159+
DeclareOpInterfaceMethods<TilingInterface,
160+
["getIterationDomain",
161+
"getLoopIteratorTypes",
162+
"getResultTilePosition",
163+
"getTiledImplementation"]>]> {
159164
let summary = "Winograd filter transform operator";
160165
let description = [{
161166
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,6 +195,20 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
190195
`outs` `(` $output `:` type($output) `)`
191196
`->` type($result)
192197
}];
198+
let extraClassDeclaration = [{
199+
ShapedType getFilterOperandType() {
200+
return cast<ShapedType>(getFilter().getType());
201+
}
202+
ShapedType getOutputOperandType() {
203+
return cast<ShapedType>(getOutput().getType());
204+
}
205+
int64_t getFilterOperandRank() {
206+
return getFilterOperandType().getRank();
207+
}
208+
int64_t getOutputOperandRank() {
209+
return getOutputOperandType().getRank();
210+
}
211+
}];
193212
let hasVerifier = 1;
194213
}
195214

@@ -234,6 +253,20 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
234253
`outs` `(` $output `:` type($output) `)`
235254
`->` type($result)
236255
}];
256+
let extraClassDeclaration = [{
257+
ShapedType getInputOperandType() {
258+
return cast<ShapedType>(getInput().getType());
259+
}
260+
ShapedType getOutputOperandType() {
261+
return cast<ShapedType>(getOutput().getType());
262+
}
263+
int64_t getInputOperandRank() {
264+
return getInputOperandType().getRank();
265+
}
266+
int64_t getOutputOperandRank() {
267+
return getOutputOperandType().getRank();
268+
}
269+
}];
237270
let hasVerifier = 1;
238271
}
239272

@@ -278,6 +311,20 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
278311
`outs` `(` $output `:` type($output) `)`
279312
`->` type($result)
280313
}];
314+
let extraClassDeclaration = [{
315+
ShapedType getValueOperandType() {
316+
return cast<ShapedType>(getValue().getType());
317+
}
318+
ShapedType getOutputOperandType() {
319+
return cast<ShapedType>(getOutput().getType());
320+
}
321+
int64_t getValueOperandRank() {
322+
return getValueOperandType().getRank();
323+
}
324+
int64_t getOutputOperandRank() {
325+
return getOutputOperandType().getRank();
326+
}
327+
}];
281328
let hasVerifier = 1;
282329
}
283330

0 commit comments

Comments
 (0)