@@ -131,59 +131,3 @@ def traverse(OP):
131131
132132 traverse (outs [0 ].op )
133133 return s
134-
135-
136- @avg_pool2d_alter_layout .register (["cuda" ])
137- def _alter_avg_pool2d_layout (attrs , inputs , tinfo ):
138- import nnvm .symbol as sym
139- copy_inputs = [s for s in inputs ]
140- new_attrs = {k : attrs [k ] for k in attrs .keys ()}
141- # NHWC -> NCHW
142- if attrs ["layout" ] != "NHWC" :
143- return None
144- new_attrs ["layout" ] = "NCHW"
145- if "target" in new_attrs :
146- del new_attrs ["target" ]
147- return sym .avg_pool2d (* copy_inputs , ** new_attrs )
148-
149-
150- @max_pool2d_alter_layout .register (["cuda" ])
151- def _alter_max_pool2d_layout (attrs , inputs , tinfo ):
152- import nnvm .symbol as sym
153- copy_inputs = [s for s in inputs ]
154- new_attrs = {k : attrs [k ] for k in attrs .keys ()}
155- # NHWC -> NCHW
156- if attrs ["layout" ] != "NHWC" :
157- return None
158- new_attrs ["layout" ] = "NCHW"
159- if "target" in new_attrs :
160- del new_attrs ["target" ]
161- return sym .max_pool2d (* copy_inputs , ** new_attrs )
162-
163-
164- @global_max_pool2d_alter_layout .register (["cuda" ])
165- def _alter_global_max_pool2d_layout (attrs , inputs , tinfo ):
166- import nnvm .symbol as sym
167- copy_inputs = [s for s in inputs ]
168- new_attrs = {k : attrs [k ] for k in attrs .keys ()}
169- # NHWC -> NCHW
170- if attrs ["layout" ] != "NHWC" :
171- return None
172- new_attrs ["layout" ] = "NCHW"
173- if "target" in new_attrs :
174- del new_attrs ["target" ]
175- return sym .global_max_pool2d (* copy_inputs , ** new_attrs )
176-
177-
178- @global_avg_pool2d_alter_layout .register (["cuda" ])
179- def _alter_global_avg_pool2d_layout (attrs , inputs , tinfo ):
180- import nnvm .symbol as sym
181- copy_inputs = [s for s in inputs ]
182- new_attrs = {k : attrs [k ] for k in attrs .keys ()}
183- # NHWC -> NCHW
184- if attrs ["layout" ] != "NHWC" :
185- return None
186- new_attrs ["layout" ] = "NCHW"
187- if "target" in new_attrs :
188- del new_attrs ["target" ]
189- return sym .global_avg_pool2d (* copy_inputs , ** new_attrs )
0 commit comments