
Commit f6866ea

[Prim] Add index_select_double_grad (PaddlePaddle#71352)
* add UT: test_selected_high_order_derivative
* remove default axis = 0
* update error msg
* update op_compat.yaml
* update UT
* update code
* only run UT on GPU
1 parent 678afc3 commit f6866ea

File tree: 5 files changed (+84, -8 lines)


paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 11 additions & 3 deletions
@@ -1029,9 +1029,17 @@ def BackwardValidationCheck(self):
 
         max_grad_tensor_position = -1
         for _, (_, _, pos) in backward_grad_inputs_map.items():
-            assert pos > max_fwd_input_position, AssertMessage(
-                pos, max_fwd_input_position
-            )
+            if pos <= max_fwd_input_position:
+                err_msg = AssertMessage(pos, max_fwd_input_position)
+                if IsInvokeForwardApi(
+                    self.grad_api_contents, self.forward_apis_dict
+                ):
+                    err_msg += (
+                        f"\n\nNOTE: '{self.backward_api_name}' is an invoke api, "
+                        "please ensure that the parameters from `forward` "
+                        "are placed at the front in the `args` section.\n"
+                    )
+                raise AssertionError(err_msg)
             max_grad_tensor_position = max(max_grad_tensor_position, pos)
 
         max_attr_position = -1
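The generator change replaces the bare assert with an explicit AssertionError and, for invoke-style backward ops, appends a hint that the parameters taken from `forward` must come first in the `args` list. A standalone sketch of the same rule for readers unfamiliar with the code generator (the function name and signature below are illustrative, not the generator's real API):

# Illustrative restatement of the check in BackwardValidationCheck:
# every grad-tensor input of a backward op must sit after all forward
# inputs in its `args` list, i.e. its position must exceed
# max_fwd_input_position.
def check_grad_input_positions(
    grad_input_positions, max_fwd_input_position, is_invoke_api, backward_api_name
):
    for pos in grad_input_positions:
        if pos <= max_fwd_input_position:
            err_msg = (
                f"grad input position {pos} is not greater than the max "
                f"forward input position {max_fwd_input_position}"
            )
            if is_invoke_api:
                err_msg += (
                    f"\n\nNOTE: '{backward_api_name}' is an invoke api, "
                    "please ensure that the parameters from `forward` "
                    "are placed at the front in the `args` section.\n"
                )
            raise AssertionError(err_msg)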

paddle/phi/ops/yaml/backward.yaml

Lines changed: 7 additions & 0 deletions
@@ -1672,6 +1672,12 @@
   data_transform :
     skip_transform : index
 
+- backward_op : index_select_double_grad
+  forward : index_select_grad (Tensor x, Tensor index, Tensor grad_out, int axis) -> Tensor(grad_x)
+  args : (Tensor index, Tensor grad_x_grad, int axis)
+  output : Tensor(grad_out_grad)
+  invoke : index_select(grad_x_grad, index, axis)
+
 - backward_op : index_select_grad
   forward : index_select(Tensor x, Tensor index, int axis) -> Tensor(out)
   args : (Tensor x, Tensor index, Tensor out_grad, int axis)
@@ -1685,6 +1691,7 @@
   no_need_buffer : x
   data_transform :
     skip_transform : index
+  backward: index_select_double_grad
 
 - backward_op : index_select_strided_grad
   forward : index_select_strided(Tensor x, int64_t index, int axis) -> Tensor(out)
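Why `invoke : index_select(grad_x_grad, index, axis)` is the right rule: `index_select_grad` scatters `grad_out` back into the positions of `x` selected by `index`, so differentiating it with respect to `grad_out` is simply the corresponding gather. A minimal dygraph sketch of this behaviour (shapes, values, and variable names below are mine for illustration; it assumes the new double-grad op is registered so that the second `paddle.grad` call can be generated):

import numpy as np
import paddle

axis = 2
x = paddle.to_tensor(np.random.rand(2, 2, 3, 2), dtype='float64', stop_gradient=False)
index = paddle.to_tensor([0, 2, 1], dtype='int64')
# v plays the role of grad_out; making it differentiable lets us
# differentiate through index_select_grad a second time.
v = paddle.to_tensor(np.random.rand(2, 2, 3, 2), dtype='float64', stop_gradient=False)

out = paddle.index_select(x, index, axis=axis)
# First-order VJP: grad_x = index_select_grad(x, index, v, axis).
(grad_x,) = paddle.grad(out, x, grad_outputs=v, create_graph=True)
# Differentiating grad_x w.r.t. v goes through index_select_double_grad,
# which, per the yaml entry above, is index_select(grad_x_grad, index, axis)
# with grad_x_grad = ones_like(grad_x) here (the implicit grad of .sum()).
(grad_v,) = paddle.grad(grad_x.sum(), v)
print(grad_v.shape)  # same shape as v / out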

paddle/phi/ops/yaml/op_compat.yaml

Lines changed: 1 addition & 0 deletions
@@ -2044,6 +2044,7 @@
   out : Out
 
 - op : index_select
+  backward : index_select_grad, index_select_double_grad
   inputs :
     {x : X, index : Index}
   outputs :

test/legacy_test/gradient_checker.py

Lines changed: 21 additions & 5 deletions
@@ -30,6 +30,17 @@ def _product(t):
     return int(np.prod(t))
 
 
+# Data types that support gradients; dtypes like int32, int64, and bool do not require grad.
+DTYPE_REQUIRES_GRAD = [
+    paddle.float16,
+    paddle.float32,
+    paddle.float64,
+    core.DataType.FLOAT16,
+    core.DataType.FLOAT32,
+    core.DataType.FLOAT64,
+]
+
+
 def dtype_to_np_dtype(dtype):
     if dtype == paddle.float32 or dtype == core.DataType.FLOAT32:
         return np.float32
@@ -84,7 +95,7 @@ def var_to_np_array_in_scope(scope, place, name):
 
 def make_jacobian(x, y_size, np_dtype):
     if isinstance(x, (base.framework.Variable, paddle.pir.Value)):
-        return np.zeros((_product(x.shape), y_size), dtype=np_dtype)
+        return np.zeros([_product(x.shape), y_size], dtype=np_dtype)
     elif isinstance(x, Sequence):
         jacobians = list(
             filter(
@@ -260,10 +271,15 @@ def run():
     x_name = x.get_defining_op().attrs()['name']
     x_shape = x.shape
     x_size = _product(x_shape)
-    np_type = dtype_to_np_dtype(x.dtype)
-    np_t = np.array(feeds[x_name]).astype(np_type)
-    np_t = np_t.flatten()
-    jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]
+    if x.dtype in DTYPE_REQUIRES_GRAD:
+        np_type = dtype_to_np_dtype(x.dtype)
+        np_t = np.array(feeds[x_name]).astype(np_type)
+        np_t = np_t.flatten()
+        jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]
+    else:
+        np_type = np.float32  # temporarily set to float32
+        jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]
+        return jacobian
 
     for i in range(x_size):
         orig = np_t[i]
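This checker change matters for the PR because index_select takes an integer `index` input: when building the numeric Jacobian, the checker would otherwise treat the int64 tensor like a differentiable input. With the new branch, non-float inputs simply get an all-zero Jacobian block in a placeholder float32 dtype and the per-element perturbation loop is skipped. A small numpy-only sketch of what that block looks like for the new test in test/legacy_test/test_nn_grad.py below (shapes taken from that test, the rest is illustrative):

import numpy as np

# Shapes from TestIndexSelectDoubleGradCheck (x_shape=[2,2,2,2], axis=2, 3 indices).
index_shape = [3]          # the int64 index input; no gradient is defined for it
out_shape = [2, 2, 3, 2]   # shape of paddle.index_select(x, index, axis=2)

# Placeholder Jacobian block for the index input: all zeros, float32,
# shape [numel(index), numel(out)], mirroring the new `else` branch above.
jacobian_block = np.zeros(
    [int(np.prod(index_shape)), int(np.prod(out_shape))], dtype=np.float32
)
print(jacobian_block.shape)  # (3, 24)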

test/legacy_test/test_nn_grad.py

Lines changed: 44 additions & 0 deletions
@@ -551,6 +551,50 @@ def test_grad(self):
         self.func(p)
 
 
+class TestIndexSelectDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        x_shape = [2, 2, 2, 2]
+        axis = 2
+        index_shape = [3]
+        dtype = np.float64
+
+        x = paddle.static.data('x', x_shape, dtype)
+        x.persistable = True
+        x.stop_gradient = False
+        index = paddle.static.data('index', index_shape, 'int64')
+        index.persistable = True
+        out = paddle.index_select(x, index, axis)
+
+        x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
+        index_arr = np.random.uniform(
+            -x_shape[axis], x_shape[axis], index_shape
+        ).astype('int64')
+        gradient_checker.double_grad_check(
+            [x, index], out, x_init=[x_arr, index_arr], place=place
+        )
+
+        def index_select_wrapper(args):
+            return paddle.index_select(*args, axis=axis)
+
+        gradient_checker.double_grad_check_for_dygraph(
+            index_select_wrapper,
+            [x, index],
+            out,
+            x_init=[x_arr, index_arr],
+            place=place,
+        )
+
+    def test_grad(self):
+        places = []
+        # free(): invalid next size (fast) may occur when
+        # executed on CPU
+        if core.is_compiled_with_cuda():
+            places.append(base.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 class TestAvgPool2DDoubleGradCheckCase1(unittest.TestCase):
 
     @prog_scope()
