@@ -155,7 +155,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
155155}
156156
157157def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158- [AllElementTypesMatch<["filter", "output"]>]> {
158+ [AllElementTypesMatch<["filter", "output"]>,
159+ DeclareOpInterfaceMethods<TilingInterface,
160+ ["getIterationDomain",
161+ "getLoopIteratorTypes",
162+ "getResultTilePosition",
163+ "getTiledImplementation"]>]> {
159164 let summary = "Winograd filter transform operator";
160165 let description = [{
161166 Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,6 +195,20 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
190195 `outs` `(` $output `:` type($output) `)`
191196 `->` type($result)
192197 }];
198+ let extraClassDeclaration = [{
199+ ShapedType getFilterOperandType() {
200+ return cast<ShapedType>(getFilter().getType());
201+ }
202+ ShapedType getOutputOperandType() {
203+ return cast<ShapedType>(getOutput().getType());
204+ }
205+ int64_t getFilterOperandRank() {
206+ return getFilterOperandType().getRank();
207+ }
208+ int64_t getOutputOperandRank() {
209+ return getOutputOperandType().getRank();
210+ }
211+ }];
193212 let hasVerifier = 1;
194213}
195214
@@ -234,6 +253,20 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
234253 `outs` `(` $output `:` type($output) `)`
235254 `->` type($result)
236255 }];
256+ let extraClassDeclaration = [{
257+ ShapedType getInputOperandType() {
258+ return cast<ShapedType>(getInput().getType());
259+ }
260+ ShapedType getOutputOperandType() {
261+ return cast<ShapedType>(getOutput().getType());
262+ }
263+ int64_t getInputOperandRank() {
264+ return getInputOperandType().getRank();
265+ }
266+ int64_t getOutputOperandRank() {
267+ return getOutputOperandType().getRank();
268+ }
269+ }];
237270 let hasVerifier = 1;
238271}
239272
@@ -278,6 +311,20 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
278311 `outs` `(` $output `:` type($output) `)`
279312 `->` type($result)
280313 }];
314+ let extraClassDeclaration = [{
315+ ShapedType getValueOperandType() {
316+ return cast<ShapedType>(getValue().getType());
317+ }
318+ ShapedType getOutputOperandType() {
319+ return cast<ShapedType>(getOutput().getType());
320+ }
321+ int64_t getValueOperandRank() {
322+ return getValueOperandType().getRank();
323+ }
324+ int64_t getOutputOperandRank() {
325+ return getOutputOperandType().getRank();
326+ }
327+ }];
281328 let hasVerifier = 1;
282329}
283330
0 commit comments