Commit c68bfc3

Merge pull request PaddlePaddle#3476 from qingqing01/bp_test

Compare the gradient consistency between GPU and CPU calculations

2 parents 766299b + b1ac863

6 files changed, +177 -103 lines

paddle/operators/sigmoid_op.cc (+2 -1)

@@ -44,7 +44,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
 
  protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
-    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
+    ctx.Output<Tensor>(framework::GradVarName("X"))
+        ->Resize(ctx.Input<Tensor>("Y")->dims());
  }
 };
 
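The grad op now registers its output under the gradient variable name of "X" and sizes it from the forward output "Y". For intuition, here is a minimal numpy sketch (not part of this commit) of the sigmoid backward rule, which shows why dX has exactly Y's shape:

    import numpy as np

    def sigmoid_backward(Y, dY):
        # dL/dX for Y = sigmoid(X) is dY * Y * (1 - Y), elementwise,
        # so the gradient of X always has the same dims as Y.
        return dY * Y * (1.0 - Y)

    Y = 1.0 / (1.0 + np.exp(-np.random.randn(3, 4)))
    dX = sigmoid_backward(Y, np.ones_like(Y))
    assert dX.shape == Y.shape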

paddle/operators/sigmoid_op.h (+1 -1)

@@ -37,7 +37,7 @@ class SigmoidKernel : public framework::OpKernel {
     auto Y = EigenVector<T>::Flatten(*output);
     auto place = context.GetEigenDevice<Place>();
 
-    Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
+    Y.device(place) = 1. / (1. + (-X).exp());
   }
 };
 
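The kernel change only simplifies the Eigen expression; numerically both forms compute the same logistic function. An illustrative numpy check of the equivalence (not part of the diff):

    import numpy as np

    X = np.random.randn(5).astype("float32")
    old = 1.0 / (1.0 + np.exp(-1.0 * X))  # mirrors 1.0 / (1.0 + (-1.0 * X).exp())
    new = 1.0 / (1.0 + np.exp(-X))        # mirrors 1. / (1. + (-X).exp())
    assert np.allclose(old, new)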

python/paddle/v2/framework/tests/CMakeLists.txt (+1 -0)

@@ -25,3 +25,4 @@ py_test(test_operator SRCS test_operator.py)
 # py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
 py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
 py_test(test_recurrent_op SRCS test_recurrent_op.py)
+py_test(test_gradient_checker SRCS test_gradient_checker.py)

python/paddle/v2/framework/tests/gradient_checker.py (+117 -97)

@@ -1,13 +1,15 @@
 import unittest
 
 import numpy
+import itertools
 import paddle.v2.framework.core as core
 from paddle.v2.framework.op import Operator
 
 __all__ = ['get_numeric_gradient']
 
 
 def create_op(op_type):
+    # TODO need to set attrs
     kwargs = dict()
     for in_name in Operator.get_op_input_names(op_type):
         kwargs[in_name] = in_name
@@ -66,7 +68,6 @@ def get_numeric_gradient(op,
     local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace(
     ))
 
-    # TODO(yuyang18): Only CPU is support now.
     cpu_ctx = core.DeviceContext.create(core.CPUPlace())
 
     def get_output():
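get_numeric_gradient (only partially visible in this hunk) estimates input gradients by perturbing each input element and re-running the forward op on the CPU. Below is a self-contained sketch of the same finite-difference idea, written against a plain Python function instead of a Paddle Operator; the helper name and the central-difference choice are illustrative assumptions, not the repository's exact implementation:

    import numpy as np

    def numeric_gradient(f, x, delta=1e-2):
        """Central-difference estimate of df(x)/dx, one element at a time.
        f maps an ndarray to a scalar; x is perturbed in place and restored."""
        grad = np.zeros_like(x)
        flat_x, flat_g = x.reshape(-1), grad.reshape(-1)
        for i in range(flat_x.size):
            orig = flat_x[i]
            flat_x[i] = orig + delta
            pos = f(x)
            flat_x[i] = orig - delta
            neg = f(x)
            flat_x[i] = orig
            flat_g[i] = (pos - neg) / (2.0 * delta)
        return grad

    # e.g. for the sigmoid op checked below: f(x) = sum(sigmoid(x))
    x = np.random.uniform(0.1, 1, (3, 2))
    g = numeric_gradient(lambda v: (1.0 / (1.0 + np.exp(-v))).sum(), x)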
@@ -109,12 +110,110 @@ def product(dim):
 
 
 class GradientChecker(unittest.TestCase):
-    def assert_is_close(self, numeric_grads, scope, max_relative_error,
-                        msg_prefix):
-        for name in numeric_grads:
-            b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
-            a = numeric_grads[name]
+    def __get_gradient(self, forward_op, backward_op, input_value, grad_names,
+                       place):
+        """Get the input gradients after running the forward and backward
+        operators on the given place.
+
+        :param forward_op: forward operator
+        :type forward_op: Operator
+        :param backward_op: backward operator
+        :type backward_op: Operator
+        :param input_value: input values.
+        :type input_value: dict{string:numpy.array}
+        :param grad_names: the names of the returned input gradients.
+        :type grad_names: a list of string
+        :param place: the device type.
+        :type place: CPUPlace or GPUPlace
+        :return: the input gradients of the given grad_names.
+        :rtype: a list of numpy.array
+        """
+        scope = core.Scope()
+        ctx = core.DeviceContext.create(place)
+
+        inputs = forward_op.inputs()
+        in_names = [item for k in inputs for item in inputs[k]]
+        outputs = forward_op.outputs()
+        out_names = [item for k in outputs for item in outputs[k]]
+
+        # create input var and set value
+        for name, value in input_value.iteritems():
+            if name not in in_names:
+                raise ValueError(name + " does not exist in Op's inputs.")
+            var = scope.new_var(name).get_tensor()
+            var.set_dims(value.shape)
+            var.set(value, place)
+
+        # run forward op
+        for out_name in out_names:
+            scope.new_var(out_name)
+        forward_op.infer_shape(scope)
+        forward_op.run(scope, ctx)
+
+        # set output var's shape
+        # set output grad to ones
+        for name in out_names:
+            out_tensor = scope.find_var(name).get_tensor()
+            grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
+            grad_tensor.set_dims(out_tensor.shape())
+            data = numpy.ones(out_tensor.shape(), dtype=numpy.float32)
+            grad_tensor.set(data, place)
+
+        # run backward op
+        for name in backward_op.outputs():
+            scope.new_var(name)
+        backward_op.infer_shape(scope)
+        backward_op.run(scope, ctx)
+
+        outs = [
+            numpy.array(scope.find_var(name).get_tensor())
+            for name in grad_names
+        ]
+        return outs
+
+    def compare_grad(self, forward_op, input_value):
+        """Compare the input gradients between CPU and GPU for the given
+        forward operator.
+
+        :param forward_op: forward operator
+        :type forward_op: Operator
+        :param input_value: input values.
+        :type input_value: dict{string:numpy.array}
+        :raises: AssertionError, if any gradient value differs between devices.
+        """
+        backward_op = core.Operator.backward(forward_op, set())
+        # return if not compiled with GPU or no GPU kernel is implemented
+        if not (core.is_compile_gpu() and backward_op.support_gpu()):
+            return
 
+        outputs = backward_op.outputs()
+        out_names = [item for k in outputs for item in outputs[k]]
+        cpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
+                                        out_names, core.CPUPlace())
+        gpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
+                                        out_names, core.GPUPlace(0))
+
+        for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads,
+                                                   out_names):
+            self.assertTrue(
+                numpy.allclose(
+                    c_grad, g_grad, atol=1e-4),
+                "output name: " + name + " has diff")
+
+    def __assert_is_close(self, numeric_grads, analytic_grads, names,
+                          max_relative_error, msg_prefix):
+        """Use relative error for the comparison.
+
+        :param numeric_grads: the numerical gradients.
+        :type numeric_grads: a list of numpy.array
+        :param analytic_grads: the analytical gradients.
+        :type analytic_grads: a list of numpy.array
+        :param names: the names of the gradients, used to print for debug.
+        :type names: a list of string
+        :param msg_prefix: string info, used to print for debug.
+        :type msg_prefix: string
+        """
+        for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
             abs_a = numpy.abs(a)
             # if abs_a is nearly zero, then use abs error for a, not relative
             # error.
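The context lines above lead into the comparison loop: where the reference gradient is nearly zero the checker falls back to absolute error, otherwise it uses relative error. A hedged numpy sketch of that kind of check (the exact near-zero threshold is an assumption here):

    import numpy as np

    def max_relative_diff(a, b, zero_threshold=1e-3):
        # where |a| is nearly zero, divide by 1 instead, i.e. use absolute error
        denom = np.abs(a)
        denom[denom < zero_threshold] = 1.0  # threshold value is assumed
        return (np.abs(a - b) / denom).max()

    a = np.array([1e-8, 0.5, 2.0])
    b = np.array([2e-8, 0.5005, 2.001])
    assert max_relative_diff(a, b) < 0.007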
@@ -159,105 +258,26 @@ def check_grad(self,
 
         inputs = forward_op.inputs()
         in_names = [item for k in inputs for item in inputs[k]]
-        outputs = forward_op.outputs()
-        out_names = [item for k in outputs for item in outputs[k]]
-
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
         backward_op = core.Operator.backward(forward_op, no_grad_set)
 
-        bwd_outputs = backward_op.outputs()
-        bwd_out_names = [item for k in bwd_outputs for item in bwd_outputs[k]]
-
         places = [core.CPUPlace()]
         if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu():
             places.append(core.GPUPlace(0))
 
-        numeric_grad = dict()
-        # get numeric gradient
-        for check_name in inputs_to_check:
-            numeric_grad[check_name] = \
-                get_numeric_gradient(forward_op, input_vars, output_name,
-                                     check_name)
+        # get numerical gradients
+        numeric_grads = [
+            get_numeric_gradient(forward_op, input_vars, output_name, name)
+            for name in inputs_to_check
+        ]
 
-        # get operator gradient according to different device
+        check_names = [grad_var_name(name) for name in inputs_to_check]
         for place in places:
-            scope = core.Scope()
-            ctx = core.DeviceContext.create(place)
-
-            # create input var and set value
-            for name, value in input_vars.iteritems():
-                if name not in in_names:
-                    raise ValueError(name + " not in op.inputs_")
-                var = scope.new_var(name).get_tensor()
-                var.set_dims(value.shape)
-                var.set(value, place)
-
-            # create output var
-            for out_name in out_names:
-                scope.new_var(out_name).get_tensor()
-
-            # infer the shape of output var and compute/set value of output var
-            forward_op.infer_shape(scope)
-            forward_op.run(scope, ctx)
-
-            # create output grad var
-            # set shape as the output var
-            # set value of this grad to ones
-            for name in out_names:
-                out_tensor = scope.find_var(name).get_tensor()
-                grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
-                grad_tensor.set_dims(out_tensor.shape())
-                data = 1.0 * numpy.ones(out_tensor.shape())
-                grad_tensor.set(data, place)
-
-            # create input grad var
-            for name in bwd_out_names:
-                scope.new_var(name).get_tensor()
-
-            # infer the shape of input gradient var and compute/set it's value
-            # with backward op
-            backward_op.infer_shape(scope)
-            backward_op.run(scope, ctx)
-
-            self.assert_is_close(numeric_grad, scope, max_relative_error,
-                                 "Gradient Check On %s" % str(place))
-
-
-if __name__ == '__main__':
-
-    class GetNumericGradientTest(unittest.TestCase):
-        def test_add_op(self):
-            add_op = Operator('add_two', X="X", Y="Y", Out="Z")
-            x = numpy.random.random((10, 1)).astype("float32")
-            y = numpy.random.random((10, 1)).astype("float32")
-
-            arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
-            self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)
-
-        def test_softmax_op(self):
-            def stable_softmax(x):
-                """Compute the softmax of vector x in a numerically stable way."""
-                shiftx = x - numpy.max(x)
-                exps = numpy.exp(shiftx)
-                return exps / numpy.sum(exps)
-
-            def label_softmax_grad(Y, dY):
-                dX = Y * 0.0
-                for i in range(Y.shape[0]):
-                    d = numpy.dot(Y[i, :], dY[i, :])
-                    dX[i, :] = Y[i, :] * (dY[i, :] - d)
-                return dX
-
-            softmax_op = Operator("softmax", X="X", Y="Y")
-
-            X = numpy.random.random((2, 2)).astype("float32")
-            Y = numpy.apply_along_axis(stable_softmax, 1, X)
-            dY = numpy.ones(Y.shape)
-            dX = label_softmax_grad(Y, dY)
-
-            arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
-            numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)
-
-    unittest.main()
+            # get analytical gradients according to different device
+            analytic_grads = self.__get_gradient(forward_op, backward_op,
+                                                 input_vars, check_names, place)
+            self.__assert_is_close(numeric_grads, analytic_grads, check_names,
+                                   max_relative_error,
+                                   "Gradient Check On %s" % str(place))
python/paddle/v2/framework/tests/test_gradient_checker.py (+43 -0)

@@ -0,0 +1,43 @@
+import unittest
+import numpy
+from paddle.v2.framework.op import Operator
+from gradient_checker import GradientChecker
+from gradient_checker import get_numeric_gradient
+
+
+class GetNumericGradientTest(unittest.TestCase):
+    def test_add_op(self):
+        add_op = Operator('add_two', X="X", Y="Y", Out="Z")
+        x = numpy.random.random((10, 1)).astype("float32")
+        y = numpy.random.random((10, 1)).astype("float32")
+
+        arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
+        self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)
+
+    def test_softmax_op(self):
+        def stable_softmax(x):
+            """Compute the softmax of vector x in a numerically stable way."""
+            shiftx = x - numpy.max(x)
+            exps = numpy.exp(shiftx)
+            return exps / numpy.sum(exps)
+
+        def label_softmax_grad(Y, dY):
+            dX = Y * 0.0
+            for i in range(Y.shape[0]):
+                d = numpy.dot(Y[i, :], dY[i, :])
+                dX[i, :] = Y[i, :] * (dY[i, :] - d)
+            return dX
+
+        softmax_op = Operator("softmax", X="X", Y="Y")
+
+        X = numpy.random.random((2, 2)).astype("float32")
+        Y = numpy.apply_along_axis(stable_softmax, 1, X)
+        dY = numpy.ones(Y.shape)
+        dX = label_softmax_grad(Y, dY)
+
+        arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
+        numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)
+
+
+if __name__ == '__main__':
+    unittest.main()
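label_softmax_grad in the test applies the row-wise identity dX[i] = Y[i] * (dY[i] - Y[i].dot(dY[i])), i.e. the softmax Jacobian-vector product. As an illustrative cross-check (not part of the commit), the same result follows from the explicit per-row Jacobian diag(y) - y y^T:

    import numpy as np

    y = np.array([0.2, 0.3, 0.5])            # one softmax row
    dy = np.array([1.0, 0.0, 2.0])           # upstream gradient for that row

    jacobian = np.diag(y) - np.outer(y, y)   # d softmax_i / d x_j
    dx_full = jacobian.dot(dy)               # full Jacobian-vector product
    dx_fast = y * (dy - np.dot(y, dy))       # shortcut used by label_softmax_grad

    assert np.allclose(dx_full, dx_fast)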
python/paddle/v2/framework/tests/test_sigmoid_op.py (+13 -4)

@@ -1,19 +1,28 @@
 import unittest
-from op_test_util import OpTestMeta
 import numpy as np
+from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
 
 
 class TestSigmoidOp(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
         self.type = "sigmoid"
-        self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
+        self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
         self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
 
 
-#class TestSigmoidGradOp(unittest.TestCase):
-#TODO(qingqing) add unit test
+class TestSigmoidGradOp(GradientChecker):
+    def test_grad(self):
+        op = create_op("sigmoid")
+        inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
+        # compare gpu and cpu results for backward op.
+        # this test will be skipped if only the CPU version is compiled.
+        self.compare_grad(op, inputs)
+        # check gradients
+        self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)
+
 
 if __name__ == '__main__':
     unittest.main()
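A small Python detail in the check_grad call above: set("X") builds a set of characters, which equals {'X'} only because the input name is a single character. An illustrative reminder for multi-character names:

    assert set("X") == {"X"}                          # fine for one-letter names
    assert set("Input") == {"I", "n", "p", "u", "t"}  # not {"Input"}
    # for longer names, pass an explicit set such as {"Input"}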
