
Commit 118acb2

altanh authored and wweic committed

[Relay][Training] Add and fix gradients (apache#4126)

* add and fix gradients
* fix linter issues

1 parent 2a79b98 commit 118acb2

File tree

3 files changed: +111 -19 lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 74 additions & 8 deletions
@@ -48,6 +48,9 @@
     tile,
     transpose,
     where,
+    repeat,
+    expand_dims,
+    full_like
 )


@@ -198,6 +201,7 @@ def clip_grad(orig, grad):

 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
+    """Returns the gradient of max_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):

 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
+    """Returns the gradient of avg_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
     return [pool_grad]


+@register_gradient("nn.global_avg_pool2d")
+def global_avg_pool2d_grad(orig, grad):
+    """Returns the gradient of global_avg_pool2d."""
+    data = orig.args[0]
+    shape = data.checked_type.shape
+    layout = orig.attrs.layout
+
+    # we assume NCHW or NHWC layout for now, but easy to add more
+    assert layout in ["NCHW", "NHWC"]
+    if layout == "NCHW":
+        pool_size = shape[2], shape[3]
+    elif layout == "NHWC":
+        pool_size = shape[1], shape[2]
+
+    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
+                                    strides=(1, 1), padding=(0, 0),
+                                    layout=layout)
+    return [pool_grad]
+
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
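Note: the new global_avg_pool2d_grad simply reuses avg_pool2d_grad with the pooling window set to the full spatial extent, so every input position should receive an equal 1/(H*W) share of the incoming gradient. A minimal NumPy sketch of that expectation (illustrative shapes only, not part of the patch):

```python
import numpy as np

# Global average pooling over NCHW averages H*W values per channel, so the
# backward pass distributes the incoming gradient uniformly as 1 / (H * W).
n, c, h, w = 1, 2, 4, 4
out_grad = np.ones((n, c, 1, 1), dtype="float32")            # upstream gradient
expected_in_grad = np.broadcast_to(out_grad / (h * w), (n, c, h, w))
print(expected_in_grad[0, 0, 0, 0])                           # 0.0625 == 1 / 16
```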
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]


+def _get_reduce_axis(call):
+    """Helper function that returns the reduce axis of the call as plain python ints."""
+    x, axis = call.args[0], call.attrs.axis
+    shape = x.checked_type.concrete_shape
+
+    # should never exclude when axis is None
+    assert not (axis is None and call.attrs.exclude)
+
+    if axis is None:
+        return None
+
+    # convert to nonnegative integers and sort
+    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
+    if call.attrs.exclude:
+        axis = [ax for ax in range(len(shape)) if ax not in axis]
+    return axis
+
+
+def _unreduce_expand(x, axis):
+    """Helper function that returns x expanded on the reduced dimensions in axis."""
+    # assume axis is sorted nonnegative ints
+    for ax in axis:
+        x = expand_dims(x, ax)
+    return x
+
+
 @register_gradient("max")
 def max_grad(orig, grad):
     """Returns the gradient of max"""
-    # Only support axis=0, since broadcasting orig to x behaves incorrectly
-    x, axis = orig.args[0], orig.attrs.axis
-    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
-    orig = broadcast_to_like(orig, x)
-    grad = broadcast_to_like(grad, x)
-    indicators = cast_like(equal(orig, x), grad)
-    return [indicators * grad]
+    x, axis = orig.args[0], _get_reduce_axis(orig)
+    shape = x.checked_type.concrete_shape
+
+    repeated = orig
+    if axis is None:
+        repeated = full_like(x, repeated)
+    else:
+        # expand dims (if necessary) and repeat along each axis
+        if not orig.attrs.keepdims:
+            repeated = _unreduce_expand(repeated, axis)
+            grad = _unreduce_expand(grad, axis)
+        for ax in axis:
+            repeated = repeat(repeated, shape[ax], ax)
+
+    indicators = cast_like(equal(repeated, x), grad)
+    num_selected = _sum(indicators, axis, keepdims=True)
+    # spread error across all max weights
+    return [indicators * grad / num_selected]


 @register_gradient("nn.softmax")
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
 @register_gradient("sum")
 def sum_grad(orig, grad):
     """Returns grad broadcasted to data dims"""
-    data = orig.args[0]
+    data, axis = orig.args[0], _get_reduce_axis(orig)
+    if not orig.attrs.keepdims:
+        if axis is None:
+            axis = list(range(len(data.checked_type.concrete_shape)))
+        grad = _unreduce_expand(grad, axis)
     return [broadcast_to_like(grad, data)]

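Note: sum_grad now uses the same helpers: with keepdims=False, the incoming gradient first gets its reduced dimensions re-inserted via _unreduce_expand, and only then is it broadcast back to the input shape. A NumPy equivalent of that two-step on made-up shapes (without the expand step, the broadcast would fail):

```python
import numpy as np

data_shape = (4, 2, 3)
axis = [1]                                        # axis reduced by sum, keepdims=False
out_grad = np.ones((4, 3), dtype="float32")       # gradient of sum(data, axis=1)

expanded = np.expand_dims(out_grad, axis[0])      # (4, 1, 3): re-insert the reduced dim
in_grad = np.broadcast_to(expanded, data_shape)   # (4, 2, 3): every element gets the grad
print(in_grad.shape)                              # (4, 2, 3)
```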

tests/python/relay/test_op_grad_level2.py

Lines changed: 26 additions & 3 deletions
@@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):


 def test_max_pool2d_grad():
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
-                           ceil_mode=False)
+    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
     verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)


@@ -75,14 +74,37 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun
         op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
         np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)

-
 def test_avg_pool2d_grad():
     verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                            ceil_mode=False, count_include_pad=True)
     verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                            ceil_mode=False, count_include_pad=False)


+def verify_global_avg_pool2d_grad(x_shape):
+    x = relay.var("x", relay.TensorType(x_shape, "float32"))
+    y = tvm.relay.nn.global_avg_pool2d(x)
+
+    fwd_func = relay.Function([x], y)
+    fwd_func = run_infer_type(fwd_func)
+    bwd_func = run_infer_type(gradient(fwd_func))
+
+    data = np.random.rand(*x_shape).astype("float32")
+    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+    out_grad = np.ones(shape=y_shape)
+    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
+                                           strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
+                                           ceil_mode=False)
+
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor(ctx=ctx, target=target)
+        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+
+def test_global_avg_pool2d_grad():
+    verify_global_avg_pool2d_grad((1, 4, 16, 16))
+    verify_global_avg_pool2d_grad((1, 8, 8, 24))
+
 def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
     try:
         import torch
@@ -155,6 +177,7 @@ def test_batch_flatten_grad():
 if __name__ == "__main__":
     test_max_pool2d_grad()
     test_avg_pool2d_grad()
+    test_global_avg_pool2d_grad()
     test_conv2d_grad()
     test_dense_grad()
     test_batch_flatten_grad()
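Note: the pooling tests above compare the symbolic gradient against closed-form references from topi.testing, while the level-4 tests below rely on check_grad, which (as used here) checks it against a numerical estimate. A minimal central-difference sketch of that idea in NumPy, with a hypothetical scalar-valued f and helper name (illustrative only, not TVM's implementation):

```python
import numpy as np

def numerical_grad(f, x, eps=1e-3):
    """Central-difference estimate of df(x)/dx for a scalar-valued f (hypothetical helper)."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    while not it.finished:
        idx = it.multi_index
        saved = x[idx]
        x[idx] = saved + eps
        f_plus = f(x)
        x[idx] = saved - eps
        f_minus = f(x)
        x[idx] = saved                                # restore the perturbed entry
        grad[idx] = (f_plus - f_minus) / (2 * eps)
        it.iternext()
    return grad

x = np.random.rand(3, 4)
approx = numerical_grad(lambda a: float(a.sum()), x)
np.testing.assert_allclose(approx, np.ones_like(x), rtol=1e-2)  # d(sum)/dx == 1 everywhere
```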

tests/python/relay/test_op_grad_level4.py

Lines changed: 11 additions & 8 deletions
@@ -29,18 +29,21 @@ def test_sum_grad():
     verify_sum_grad((4, 2))
     verify_sum_grad((4, 2), axis=-1, keepdims=True)
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
+    verify_sum_grad((4, 2, 1), axis=1)


-def test_max_grad():
-    s = (10, 10)
-    t = relay.TensorType(s)
-    x = relay.var("x", t)
-    axis = 0
-    z = relay.max(x, axis)
-
-    fwd_func = relay.Function([x], z)
+def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
+    data = relay.var("data", relay.TensorType(d_shape, "float32"))
+    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
     check_grad(fwd_func, scale=1e-3)


+def test_max_grad():
+    verify_max_grad((10, 10), axis=None)
+    verify_max_grad((10, 10), axis=-1)
+    verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
+    verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
+
+
 if __name__ == "__main__":
     pytest.main()
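Note: the new verify_max_grad cases exercise the exclude flag that _get_reduce_axis normalizes away: with exclude=True, the reduction runs over every axis not listed. A quick NumPy illustration of what axis=(0, 2), exclude=True means for a rank-3 tensor (example data is made up):

```python
import numpy as np

x = np.random.rand(5, 4, 3).astype("float32")
listed = (0, 2)
# exclude=True: reduce over the complement of the listed axes, i.e. axis 1 only
reduce_axes = tuple(ax for ax in range(x.ndim) if ax not in listed)
print(reduce_axes)                      # (1,)
print(x.max(axis=reduce_axes).shape)    # (5, 3)
```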
