Commit f124c86
Authored by JiabinYang, Aurelius84, thisjiang, cxxly and 2742195759
【Prim】Custom softmax grad (PaddlePaddle#51474)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* Cxx prim custom vjp (#8)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* [dy2static-ci] fix dy2static ci errors.
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [Prim] enable whitelist and blacklist for custom_vjp
* support softmax grad
* remove additional code
* add test back
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: xiongkun <807377414@qq.com>
1 parent 50df017 commit f124c86
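For orientation, the sketch below is condensed from the unit test this commit adds: it builds a static-graph program containing F.softmax, lowers it to primitive ops with to_prim, and checks that neither the fused softmax op nor a softmax_grad op remains in the block. The calls used (core._set_prim_forward_enabled, paddle.incubate.autograd.primapi.to_prim, paddle.static.gradients) are taken directly from that test; the shape and axis are illustrative, not part of the commit.

import paddle
import paddle.nn.functional as F
from paddle.fluid import core

paddle.enable_static()
core._set_prim_forward_enabled(True)  # turn on composite (prim) lowering

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data('x', shape=[2, 3, 4], dtype='float32')
    x.stop_gradient = False
    y = F.softmax(x, axis=-1)

    # Lower the forward softmax to primitive ops, then build the backward.
    paddle.incubate.autograd.primapi.to_prim(main_program.blocks)
    x_grad = paddle.static.gradients([y], x)

    op_types = [op.type for op in main_program.blocks[0].ops]
    assert 'softmax' not in op_types       # decomposed by to_prim
    assert 'softmax_grad' not in op_types  # backward built from primitives

core._set_prim_forward_enabled(False)
paddle.disable_static()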

File tree: 4 files changed, 251 additions and 0 deletions

paddle/fluid/operators/softmax_op.cc

Lines changed: 21 additions & 0 deletions

@@ -19,6 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"

@@ -156,6 +159,23 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class SoftmaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor out = this->GetSingleForwardOutput("Out");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor dx = this->GetSingleInputGrad("X");
+    auto* dx_ptr = this->GetOutputPtr(&dx);
+    std::string dx_name = this->GetOutputName(dx);
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    VLOG(6) << "Running softmax_grad composite func";
+    prim::softmax_grad<prim::DescTensor>(out, out_grad, axis, dx_ptr);
+    this->RecoverOutputName(dx, dx_name);
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
 
 }  // namespace operators

@@ -172,6 +192,7 @@ REGISTER_OPERATOR(softmax,
                   ops::SoftmaxOpInferVarType,
                   ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
                   ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
+                  ops::SoftmaxCompositeGradOpMaker,
                   ops::SoftmaxInplaceInferer,
                   SoftmaxInferShapeFunctor);
 DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad,
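For reference, the rule that this composite grad maker delegates to (prim::softmax_grad, added in composite_backward_api.h below) is the standard softmax vector-Jacobian product. With y = softmax(x) along the chosen axis and dy the incoming gradient of Out:

    \frac{\partial L}{\partial x_i} = y_i \left( \mathrm{d}y_i - \sum_j y_j \, \mathrm{d}y_j \right)

which maps directly onto the composite implementation: new_out_grad = out_grad * out, then x_grad = new_out_grad - out * sum<T>(new_out_grad, {axis}, out.dtype(), true).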

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 29 additions & 0 deletions

@@ -30,6 +30,35 @@ using Tensor = paddle::Tensor;
 using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
 //  This function should have as same signature as phi, which defined in
 //  paddle/phi/api/backward/backward_api.h
+template <typename T>
+void softmax_grad(const Tensor& out,
+                  const Tensor& out_grad,
+                  int axis,
+                  Tensor* x_grad) {
+  if (x_grad) {
+    if (out_grad.dims().size() > 0) {
+      if (axis >= 0) {
+        auto new_out_grad = out_grad * out;
+        auto tmp_x_grad = new_out_grad -
+                          out * sum<T>(new_out_grad, {axis}, out.dtype(), true);
+        set_output<T>(tmp_x_grad, x_grad);
+      } else {
+        auto new_out_grad = out_grad * out;
+        auto tmp_x_grad =
+            new_out_grad - out * sum<T>(new_out_grad,
+                                        {out.dims().size() + axis},
+                                        out.dtype(),
+                                        true);
+        set_output<T>(tmp_x_grad, x_grad);
+      }
+    } else {
+      set_output<T>(
+          full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()),
+          x_grad);
+    }
+  }
+}
+
 template <typename T>
 void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
   if (x_grad) {
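As a sanity check on that formula, here is a small NumPy sketch (not part of the commit; the helper names, shapes, and tolerances are illustrative) comparing the composite rule against a central-difference gradient of softmax:

# Hedged sanity check: the composite rule
#   x_grad = y*dy - y*sum(y*dy, axis, keepdims=True)
# should match a central-difference gradient of L(x) = sum(softmax(x) * dy).
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def composite_softmax_grad(y, dy, axis=-1):
    # Mirrors prim::softmax_grad: new_out_grad = out_grad * out, then
    # subtract out * sum(new_out_grad) along the softmax axis.
    new_out_grad = dy * y
    return new_out_grad - y * new_out_grad.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
dy = rng.standard_normal((2, 3, 4))

analytic = composite_softmax_grad(softmax(x), dy)

# Central differences on the scalar loss L(x) = sum(softmax(x) * dy).
eps = 1e-5
numeric = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    numeric[idx] = ((softmax(xp) * dy).sum() - (softmax(xm) * dy).sum()) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=1e-4, atol=1e-6)
print("composite softmax grad matches finite differences")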

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 1 addition & 0 deletions

@@ -1144,6 +1144,7 @@
     param : [out]
   kernel :
     func : softmax_grad
+  composite : softmax_grad(out, out_grad, axis, x_grad)
 
 - backward_op : spectral_norm_grad
   forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)

Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from utils import TOLERANCE

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def generate_data(shape, dtype="float32"):
    np_data = np.random.random(shape).astype(dtype)
    return np_data


class Attr:
    def __init__(self) -> None:
        self.dtype = None
        self.axis = -1
        self.shape = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
        return

    def set_axis(self, axis) -> None:
        self.axis = axis
        return

    def set_shape(self, shape) -> None:
        self.shape = shape
        return

    def get_rtol(self, flag):
        rtol = TOLERANCE[self.dtype][flag].get("rtol")
        return rtol

    def get_atol(self, flag):
        atol = TOLERANCE[self.dtype][flag].get("atol")
        return atol


attrs = Attr()


def fn(x):
    return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)


def expect_grad(inputs):
    paddle.disable_static()
    inputs.stop_gradient = False
    res = fn(inputs)

    gradients = paddle.grad(res, inputs)
    return gradients


class TestCompositeSoftmax(unittest.TestCase):
    def setUp(self):
        self.dtypes = ["float32", "float64"]
        self.shapes = [[2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
        core._set_prim_forward_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            x.stop_gradient = False
            y = fn(x)
            blocks = main_program.blocks

            fwd_ops = [op.type for op in blocks[0].ops]
            # Ensure that softmax is in the original block.
            self.assertTrue('softmax' in fwd_ops)

            paddle.incubate.autograd.primapi.to_prim(blocks)

            fwd_ops_new = [op.type for op in blocks[0].ops]
            # Ensure that softmax has been split into smaller primitive ops.
            self.assertTrue('softmax' not in fwd_ops_new)

            z = paddle.static.gradients([y], x)
            fwd_ops_grad = [op.type for op in blocks[0].ops]
            # Ensure that softmax_grad is not in the grad block.
            self.assertTrue('softmax_grad' not in fwd_ops_grad)

        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
        core._set_prim_forward_enabled(False)
        return res

    def compare_backward(self):
        np_data = generate_data(attrs.shape)
        tensor_data = paddle.to_tensor(np_data)

        expect = expect_grad(tensor_data)[0].numpy()
        actual = self.cal_composite_grad(np_data)[0]

        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("backward"),
            atol=attrs.get_atol("backward"),
        )

    def test_backward(self):
        for i in self.axes:
            for j in self.dtypes:
                for t in self.shapes:
                    attrs.set_axis(i)
                    attrs.set_dtype(j)
                    attrs.set_shape(t)
                    self.compare_backward()


class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
    "test composite softmax and prim backward"

    def setUp(self):
        core._set_prim_backward_enabled(True)
        self.dtypes = ["float32", "float64"]
        self.shapes = [[], [2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
        core._set_prim_all_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            x.stop_gradient = False
            y = fn(x)
            blocks = main_program.blocks
            z = paddle.static.gradients([y], x)
            paddle.incubate.autograd.primapi.to_prim(blocks)

        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
        core._set_prim_all_enabled(False)
        return res

    def compare_backward(self):
        if not attrs.shape and attrs.axis not in [-1, 0]:
            # The softmax op does not support a 0-D input combined with an
            # out-of-range axis, so skip this case.
            return
        np_data = generate_data(attrs.shape)
        tensor_data = paddle.to_tensor(np_data)

        expect = expect_grad(tensor_data)[0].numpy()
        actual = self.cal_composite_grad(np_data)[0]

        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("prim_backward"),
            atol=attrs.get_atol("prim_backward"),
        )

    def test_prim_backward(self):
        for i in self.axes:
            for j in self.dtypes:
                for t in self.shapes:
                    attrs.set_axis(i)
                    attrs.set_dtype(j)
                    attrs.set_shape(t)
                    self.compare_backward()


if __name__ == '__main__':
    unittest.main()
