4848 tile ,
4949 transpose ,
5050 where ,
51+ repeat ,
52+ expand_dims ,
53+ full_like
5154)
5255
5356
@@ -198,6 +201,7 @@ def clip_grad(orig, grad):
198201
199202@register_gradient ("nn.max_pool2d" )
200203def max_pool2d_grad (orig , grad ):
204+ """Returns the gradient of max_pool2d."""
201205 attrs = orig .attrs
202206 pool_grad = _nn .max_pool2d_grad (grad , orig .args [0 ], pool_size = attrs .pool_size ,
203207 strides = attrs .strides , padding = attrs .padding ,
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):
207211
208212@register_gradient ("nn.avg_pool2d" )
209213def avg_pool2d_grad (orig , grad ):
214+ """Returns the gradient of avg_pool2d."""
210215 attrs = orig .attrs
211216 pool_grad = _nn .avg_pool2d_grad (grad , orig .args [0 ], pool_size = attrs .pool_size ,
212217 strides = attrs .strides , padding = attrs .padding ,
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
215220 return [pool_grad ]
216221
217222
@register_gradient("nn.global_avg_pool2d")
def global_avg_pool2d_grad(orig, grad):
    """Returns the gradient of global_avg_pool2d."""
    data = orig.args[0]
    shape = data.checked_type.shape
    layout = orig.attrs.layout

    # we assume NCHW or NHWC layout for now, but easy to add more
    assert layout in ["NCHW", "NHWC"]
    # spatial (H, W) dims sit at axes 2,3 for NCHW and 1,2 for NHWC
    height_axis, width_axis = (2, 3) if layout == "NCHW" else (1, 2)
    pool_size = shape[height_axis], shape[width_axis]

    # a global average pool is an average pool whose window covers the
    # whole spatial extent, so reuse avg_pool2d's gradient directly
    pool_grad = _nn.avg_pool2d_grad(grad, data,
                                    pool_size=pool_size,
                                    strides=(1, 1),
                                    padding=(0, 0),
                                    layout=layout)
    return [pool_grad]
241+
242+
218243# not implemented, this is only for testing.
219244@register_gradient ("concatenate" )
220245def concatenate_grad (orig , grad ):
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
287312 return [backward_data , backward_weight ]
288313
289314
315+ def _get_reduce_axis (call ):
316+ """Helper function that returns the reduce axis of the call as plain python ints."""
317+ x , axis = call .args [0 ], call .attrs .axis
318+ shape = x .checked_type .concrete_shape
319+
320+ # should never exclude when axis is None
321+ assert not (axis is None and call .attrs .exclude )
322+
323+ if axis is None :
324+ return None
325+
326+ # convert to nonnegative integers and sort
327+ axis = sorted ([ax if ax >= 0 else len (shape ) + ax for ax in map (int , axis )])
328+ if call .attrs .exclude :
329+ axis = [ax for ax in range (len (shape )) if ax not in axis ]
330+ return axis
331+
332+
def _unreduce_expand(x, axis):
    """Helper function that returns x expanded on the reduced dimensions in axis."""
    # axis is assumed to be sorted nonnegative ints; inserting in ascending
    # order keeps each later index valid after the earlier insertions
    expanded = x
    for ax in axis:
        expanded = expand_dims(expanded, ax)
    return expanded
339+
340+
@register_gradient("max")
def max_grad(orig, grad):
    """Returns the gradient of max"""
    x, axis = orig.args[0], _get_reduce_axis(orig)
    shape = x.checked_type.concrete_shape

    if axis is None:
        # reduced over every axis: the result is scalar-like, so just
        # broadcast it back to the input's shape
        repeated = full_like(x, orig)
    else:
        repeated = orig
        # expand dims (if necessary) and repeat along each reduced axis
        # so `repeated` matches x's shape for the element-wise compare
        if not orig.attrs.keepdims:
            repeated = _unreduce_expand(repeated, axis)
            grad = _unreduce_expand(grad, axis)
        for ax in axis:
            repeated = repeat(repeated, shape[ax], ax)

    # 1.0 wherever x attains the max, 0.0 elsewhere
    indicators = cast_like(equal(repeated, x), grad)
    num_selected = _sum(indicators, axis, keepdims=True)
    # spread error across all max weights
    return [indicators * grad / num_selected]
300362
301363
302364@register_gradient ("nn.softmax" )
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
@register_gradient("sum")
def sum_grad(orig, grad):
    """Returns grad broadcasted to data dims"""
    data = orig.args[0]
    axis = _get_reduce_axis(orig)
    if not orig.attrs.keepdims:
        if axis is None:
            # summed over every axis: reinstate all dimensions before broadcast
            axis = range(len(data.checked_type.concrete_shape))
        grad = _unreduce_expand(grad, axis)
    return [broadcast_to_like(grad, data)]
377443
378444
0 commit comments