@@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154 let hasVerifier = 1;
155155}
156156
157- def Linalg_WinogradFilterTransformOp :
158- Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
157+ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+ [AllElementTypesMatch<["filter", "output"]>,
159+ DeclareOpInterfaceMethods<TilingInterface,
160+ ["getIterationDomain",
161+ "getLoopIteratorTypes",
162+ "getResultTilePosition",
163+ "getTiledImplementation"]>]> {
159164 let summary = "Winograd filter transform operator";
160165 let description = [{
161166 Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
190195 `outs` `(` $output `:` type($output) `)`
191196 `->` type($result)
192197 }];
198+ let extraClassDeclaration = [{
199+ ShapedType getFilterOperandType() {
200+ return cast<ShapedType>(getFilter().getType());
201+ }
202+ ShapedType getOutputOperandType() {
203+ return cast<ShapedType>(getOutput().getType());
204+ }
205+ int64_t getFilterOperandRank() {
206+ return getFilterOperandType().getRank();
207+ }
208+ int64_t getOutputOperandRank() {
209+ return getOutputOperandType().getRank();
210+ }
211+ int64_t getFilterFDim() {
212+ return 0;
213+ }
214+ int64_t getFilterHDim() {
215+ return 1;
216+ }
217+ int64_t getFilterWDim() {
218+ return 2;
219+ }
220+ int64_t getFilterCDim() {
221+ return 3;
222+ }
223+ }];
193224 let hasVerifier = 1;
194225}
195226
196- def Linalg_WinogradInputTransformOp :
197- Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
227+ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
228+ [AllElementTypesMatch<["input", "output"]>,
229+ DeclareOpInterfaceMethods<TilingInterface,
230+ ["getIterationDomain",
231+ "getLoopIteratorTypes",
232+ "getResultTilePosition",
233+ "getTiledImplementation"]>]> {
198234 let summary = "Winograd input transform operator";
199235 let description = [{
200236 Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
229265 `outs` `(` $output `:` type($output) `)`
230266 `->` type($result)
231267 }];
268+ let extraClassDeclaration = [{
269+ ShapedType getInputOperandType() {
270+ return cast<ShapedType>(getInput().getType());
271+ }
272+ ShapedType getOutputOperandType() {
273+ return cast<ShapedType>(getOutput().getType());
274+ }
275+ int64_t getInputOperandRank() {
276+ return getInputOperandType().getRank();
277+ }
278+ int64_t getOutputOperandRank() {
279+ return getOutputOperandType().getRank();
280+ }
281+ int64_t getInputNDim() {
282+ return 0;
283+ }
284+ int64_t getInputHDim() {
285+ return 1;
286+ }
287+ int64_t getInputWDim() {
288+ return 2;
289+ }
290+ int64_t getInputCDim() {
291+ return 3;
292+ }
293+ int64_t getOutputAlphaHDim() {
294+ return 0;
295+ }
296+ int64_t getOutputAlphaWDim() {
297+ return 1;
298+ }
299+ int64_t getOutputTileHDim() {
300+ return 2;
301+ }
302+ int64_t getOutputTileWDim() {
303+ return 3;
304+ }
305+ int64_t getOutputNDim() {
306+ return 4;
307+ }
308+ int64_t getOutputCDim() {
309+ return 5;
310+ }
311+ }];
232312 let hasVerifier = 1;
233313}
234314
235- def Linalg_WinogradOutputTransformOp :
236- Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
315+ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
316+ [AllElementTypesMatch<["value", "output"]>,
317+ DeclareOpInterfaceMethods<TilingInterface,
318+ ["getIterationDomain",
319+ "getLoopIteratorTypes",
320+ "getResultTilePosition",
321+ "getTiledImplementation"]>]> {
237322 let summary = "Winograd output transform operator";
238323 let description = [{
239324 Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
268353 `outs` `(` $output `:` type($output) `)`
269354 `->` type($result)
270355 }];
356+ let extraClassDeclaration = [{
357+ ShapedType getValueOperandType() {
358+ return cast<ShapedType>(getValue().getType());
359+ }
360+ ShapedType getOutputOperandType() {
361+ return cast<ShapedType>(getOutput().getType());
362+ }
363+ int64_t getValueOperandRank() {
364+ return getValueOperandType().getRank();
365+ }
366+ int64_t getOutputOperandRank() {
367+ return getOutputOperandType().getRank();
368+ }
369+ int64_t getValueAlphaHDim() {
370+ return 0;
371+ }
372+ int64_t getValueAlphaWDim() {
373+ return 1;
374+ }
375+ int64_t getValueTileHDim() {
376+ return 2;
377+ }
378+ int64_t getValueTileWDim() {
379+ return 3;
380+ }
381+ int64_t getValueNDim() {
382+ return 4;
383+ }
384+ int64_t getValueFDim() {
385+ return 5;
386+ }
387+ int64_t getOutputNDim() {
388+ return 0;
389+ }
390+ int64_t getOutputHDim() {
391+ return 1;
392+ }
393+ int64_t getOutputWDim() {
394+ return 2;
395+ }
396+ int64_t getOutputFDim() {
397+ return 3;
398+ }
399+ }];
271400 let hasVerifier = 1;
272401}
273402
0 commit comments