@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154 let hasVerifier = 1;
155155}
156156
157+ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
158+ let summary = "Winograd filter transform operator";
159+ let description = [{
160+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
161+ matrix multiply. Before the matrix multiply, it will convert filter and
162+ input into a format suitable for batched matrix multiply. After the matrix
163+ multiply, it will convert output to the final result tensor.
164+
165+ The algorithm F(m x m, r x r) is
166+
167+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
168+
169+ The size of output Y is m x m. The size of filter g is r x r. The size of
170+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
171+ transformation matrices.
172+
173+ This operator is defined to represent the high level concept of filter
174+ transformation (G x g x G^T) in the Winograd Conv2D algorithm.
175+ }];
176+
177+ let arguments = (ins AnyRankedTensor:$filter,
178+ AnyRankedTensor:$output,
179+ I64Attr:$m,
180+ I64Attr:$r
181+ );
182+
183+ let results = (outs AnyRankedTensor:$result);
184+ let assemblyFormat = [{
185+ attr-dict
186+ `m` `(` $m `)`
187+ `r` `(` $r `)`
188+ `ins` `(` $filter `:` type($filter) `)`
189+ `outs` `(` $output `:` type($output) `)`
190+ `->` type($result)
191+ }];
192+ let hasVerifier = 1;
193+ }
194+
195+ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
196+ let summary = "Winograd input transform operator";
197+ let description = [{
198+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
199+ matrix multiply. Before the matrix multiply, it will convert filter and
200+ input into a format suitable for batched matrix multiply. After the matrix
201+ multiply, it will convert output to the final result tensor.
202+
203+ The algorithm F(m x m, r x r) is
204+
205+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
206+
207+ The size of output Y is m x m. The size of filter g is r x r. The size of
208+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
209+ transformation matrices.
210+
211+ This operator is defined to represent the high level concept of input
212+ transformation (B^T x d x B) in the Winograd Conv2D algorithm.
213+ }];
214+
215+ let arguments = (ins AnyRankedTensor:$input,
216+ AnyRankedTensor:$output,
217+ I64Attr:$m,
218+ I64Attr:$r
219+ );
220+
221+ let results = (outs AnyRankedTensor:$result);
222+ let assemblyFormat = [{
223+ attr-dict
224+ `m` `(` $m `)`
225+ `r` `(` $r `)`
226+ `ins` `(` $input `:` type($input) `)`
227+ `outs` `(` $output `:` type($output) `)`
228+ `->` type($result)
229+ }];
230+ let hasVerifier = 1;
231+ }
232+
233+ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
234+ let summary = "Winograd output transform operator";
235+ let description = [{
236+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
237+ matrix multiply. Before the matrix multiply, it will convert filter and
238+ input into a format suitable for batched matrix multiply. After the matrix
239+ multiply, it will convert output to the final result tensor.
240+
241+ The algorithm F(m x m, r x r) is
242+
243+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
244+
245+ The size of output Y is m x m. The size of filter g is r x r. The size of
246+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
247+ transformation matrices.
248+
249+ This operator is defined to represent the high level concept of output
250+ transformation (A^T x y x A) in the Winograd Conv2D algorithm.
251+ }];
252+
253+ let arguments = (ins AnyRankedTensor:$value,
254+ AnyRankedTensor:$output,
255+ I64Attr:$m,
256+ I64Attr:$r
257+ );
258+
259+ let results = (outs AnyRankedTensor:$result);
260+ let assemblyFormat = [{
261+ attr-dict
262+ `m` `(` $m `)`
263+ `r` `(` $r `)`
264+ `ins` `(` $value `:` type($value) `)`
265+ `outs` `(` $output `:` type($output) `)`
266+ `->` type($result)
267+ }];
268+ let hasVerifier = 1;
269+ }
270+
157271#endif // LINALG_OPS
0 commit comments