Commit 802a81d

Authored by xiaoguoguo626807, xysheng-baidu, Aurelius84, thisjiang, and cxxly
【prim】New layer_norm grad (PaddlePaddle#51750)
* Add flatten composite rule
* get the right xshape and pass func test
* add cinn unit test
* Remove cinn test, wait for it to be added after repair
* add comp test to test_flatten_contiguous_range_op.py
* remove func test on composite_ops
* Add comments to maybe_wrap_dim func
* remove commented code
* fix the problem with 0D tensor case
* add flatten split rule comment
* fix syntax issues
* block flatten on resnet_prim_cinn
* init change
* tmp commit
* add layer_norm InferMeta check
* cast type modify
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN] Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN] Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* recover
* big tol
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN] Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN] Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* Cxx prim custom vjp (#8)
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN] Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
* [CINN] Enhance CacheKey hash logic by considering input dtypes
---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* [dy2static-ci] fix dy2static ci errors.
---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
* [Prim] enable whitelist and blacklist for custom_vjp
* debug log
* clear log
* fix
* nothing
* less memory
* recover utils
* fix
* modify threshold value
* skip layer_norm for test_bert
* back to bert success state
* add epsion
* delete unnecessary compute
* modify amp dtype
* modify
* order
* delete sqrt check and fp16
---------
Co-authored-by: xuyongsheng <xuyongsheng@baidu.com>
Co-authored-by: xysheng-baidu <121540080+xysheng-baidu@users.noreply.github.com>
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: xiongkun <807377414@qq.com>
1 parent b81188f commit 802a81d

9 files changed (+330, -49 lines)


paddle/fluid/operators/layer_norm_op.cc

Lines changed: 71 additions & 2 deletions
@@ -15,7 +15,13 @@ limitations under the License. */
 #include <memory>
 #include <string>

+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/ternary.h"

 namespace paddle {
 namespace operators {
@@ -253,15 +259,78 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
                                     "Bias");

+class LayerNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // get inputs
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
+    paddle::Tensor var = this->GetSingleForwardOutput("Variance");
+    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
+    paddle::optional<paddle::Tensor> scale =
+        this->GetOptionalSingleForwardInput("Scale");
+    paddle::optional<paddle::Tensor> bias =
+        this->GetOptionalSingleForwardInput("Bias");
+
+    // get Attrs
+    auto epsilon = this->Attr<float>("epsilon");
+    auto begin_norm_axis = this->Attr<int>("begin_norm_axis");
+
+    // get outputs
+    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
+    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
+    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
+
+    auto dx_ptr = this->GetOutputPtr(&x_grad);
+    std::string dx_name = this->GetOutputName(x_grad);
+    auto dscale_ptr = this->GetOutputPtr(&scale_grad);
+    std::string dscale_name = this->GetOutputName(scale_grad);
+    auto dbias_ptr = this->GetOutputPtr(&bias_grad);
+    std::string dbias_name = this->GetOutputName(bias_grad);
+
+    VLOG(6) << "Runing layer_norm_grad composite func";
+    prim::layer_norm_grad<prim::DescTensor>(x,
+                                            scale,
+                                            bias,
+                                            mean,
+                                            var,
+                                            y_grad,
+                                            epsilon,
+                                            begin_norm_axis,
+                                            dx_ptr,
+                                            dscale_ptr,
+                                            dbias_ptr);
+
+    this->RecoverOutputName(x_grad, dx_name);
+    this->RecoverOutputName(scale_grad, dscale_name);
+    this->RecoverOutputName(bias_grad, dbias_name);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(layer_norm,
+                            LayerNormInferShapeFunctor,
+                            PD_INFER_META(phi::LayerNormInferMeta));
+
 REGISTER_OPERATOR(layer_norm,
                   ops::LayerNormOp,
                   ops::LayerNormOpMaker,
                   ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
-                  ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
+                  ops::LayerNormGradOpMaker<paddle::imperative::OpBase>,
+                  ops::LayerNormCompositeGradOpMaker,
+                  LayerNormInferShapeFunctor);
+
+DECLARE_INFER_SHAPE_FUNCTOR(layer_norm_grad,
+                            LayerNormGradInferShapeFunctor,
+                            PD_INFER_META(phi::LayerNormGradInferMeta));
+
 REGISTER_OPERATOR(layer_norm_grad,
                   ops::LayerNormGradOp,
-                  ops::LayerNormGradNoNeedBufferVarInferer);
+                  ops::LayerNormGradNoNeedBufferVarInferer,
+                  LayerNormGradInferShapeFunctor);
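For reference, the ops::LayerNormCompositeGradOpMaker registered above delegates to prim::layer_norm_grad, which rebuilds the layer normalization backward pass from primitive ops. Writing $\hat{x}_{ij} = (x_{ij} - \mu_i)/\sqrt{\sigma_i^2 + \epsilon}$ for the normalized input, $g_{ij} = \gamma_j\,\partial L/\partial y_{ij}$ for the scaled output gradient, and $n$ for the number of normalized elements per row (`shape_2` in the composite rule), the gradients being decomposed are the standard ones (summarized here, not quoted from the patch):

$$
\frac{\partial L}{\partial \beta_j} = \sum_i \frac{\partial L}{\partial y_{ij}},\qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i \hat{x}_{ij}\,\frac{\partial L}{\partial y_{ij}},\qquad
\frac{\partial L}{\partial x_{ij}} = \frac{1}{\sqrt{\sigma_i^2+\epsilon}}\Big(g_{ij} - \frac{1}{n}\sum_k g_{ik} - \frac{\hat{x}_{ij}}{n}\sum_k \hat{x}_{ik}\,g_{ik}\Big).
$$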

paddle/fluid/prim/api/api.yaml

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 - tile
 - transpose
 - pad
+- sqrt
 - cumsum
 - put_along_axis
 - greater_than

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 95 additions & 0 deletions
@@ -896,6 +896,101 @@ void slice_grad(const Tensor& input,
 }
 }

+template <typename T>
+void layer_norm_grad(const Tensor& x,
+                     const paddle::optional<Tensor>& scale,
+                     const paddle::optional<Tensor>& bias,
+                     const Tensor& mean,
+                     const Tensor& variance,
+                     const Tensor& out_grad,
+                     float epsilon,
+                     int begin_norm_axis,
+                     Tensor* x_grad,
+                     Tensor* scale_grad,
+                     Tensor* bias_grad) {
+  auto x_dims = x.dims();
+  auto shape_1 = 1;  // front part
+  auto shape_2 = 1;  // back part
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    shape_1 *= x_dims[i];
+  }
+  for (int i = begin_norm_axis; i < x.dims().size(); ++i) {
+    shape_2 *= x_dims[i];
+  }
+  auto scale_ptr = scale.get_ptr();
+  auto bias_ptr = bias.get_ptr();
+
+  // cast dtype to float32 if dtype =float16
+  Tensor x_cast = x;
+  Tensor out_grad_cast = out_grad;
+  Tensor scale_cast;
+  if (scale_ptr) {
+    scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
+  }
+  if (x.dtype() == phi::DataType::FLOAT16) {
+    x_cast = cast<T>(x, phi::DataType::FLOAT32);
+    out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
+    if (scale_ptr) {
+      scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
+    }
+  }
+
+  x_cast = reshape<T>(x_cast, std::vector<int64_t>({shape_1, shape_2}));
+  out_grad_cast =
+      reshape<T>(out_grad_cast, std::vector<int64_t>({shape_1, shape_2}));
+  auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
+  auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
+  if (bias_grad) {
+    if (bias_ptr) {
+      auto bias_grad_tmp =
+          out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
+      bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
+      set_output<T>(bias_grad_tmp, bias_grad);
+    } else {
+      bias_grad = nullptr;
+    }
+  }
+  auto x_sub_mean = x_cast - mean_;
+  auto tmp = (1.0 / (variance_ + epsilon));
+  auto sqrt_var_1 = sqrt<T>(tmp);
+  if (scale_grad) {
+    if (scale_ptr) {
+      auto scale_grad_tmp =
+          (x_sub_mean * sqrt_var_1 * out_grad_cast)
+              .sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
+      scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
+      set_output<T>(scale_grad_tmp, scale_grad);
+    } else {
+      scale_grad = nullptr;
+    }
+  }
+
+  if (x_grad) {
+    if (!scale_ptr) {
+      scale_cast =
+          full<T>(std::vector<int64_t>({1, shape_2}), 1.0, x_cast.dtype());
+    }
+    auto out_grad_scale = out_grad_cast * scale_cast;
+    auto dx_end = (sqrt_var_1 * out_grad_scale);
+    auto d_mean_0 =
+        (-dx_end).sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
+    auto d_mean = (1.0 / shape_2) * d_mean_0;
+    auto d_std_1 = (-tmp * x_sub_mean * out_grad_scale)
+                       .sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
+    auto d_std_2 = (1.0 / shape_2) * sqrt_var_1;
+    d_std_2 = reshape<T>(d_std_2, std::vector<int64_t>({shape_1, 1}));
+    d_std_2 = d_std_2 * x_sub_mean;
+    auto d_std = d_std_1 * d_std_2;
+
+    auto x_grad_tmp = dx_end + d_mean + d_std;
+    x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
+    if (x.dtype() == phi::DataType::FLOAT16) {
+      x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
+    }
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+}
+
 template <typename T>
 void cumsum_grad(const Tensor& x,
                  const Tensor& out_grad,
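To make the decomposition above easier to check, here is a NumPy mirror of the same computation (a sketch, not part of the patch; the function name `layer_norm_grad_ref` is illustrative, and it assumes `var` is the saved variance produced by the forward op):

```python
import numpy as np


def layer_norm_grad_ref(x, scale, mean, var, y_grad, epsilon, begin_norm_axis):
    # Flatten to (batch, normalized) exactly like the composite rule above.
    orig_shape = x.shape
    shape_1 = int(np.prod(orig_shape[:begin_norm_axis], dtype=np.int64))
    shape_2 = int(np.prod(orig_shape[begin_norm_axis:], dtype=np.int64))

    x2 = x.reshape(shape_1, shape_2).astype(np.float32)
    dy = y_grad.reshape(shape_1, shape_2).astype(np.float32)
    mean_ = mean.reshape(shape_1, 1).astype(np.float32)
    var_ = var.reshape(shape_1, 1).astype(np.float32)
    scale_ = (np.ones((1, shape_2), np.float32) if scale is None
              else scale.reshape(1, shape_2).astype(np.float32))

    x_sub_mean = x2 - mean_
    inv_var = 1.0 / (var_ + epsilon)   # "tmp" in the C++ rule
    inv_std = np.sqrt(inv_var)         # "sqrt_var_1"

    d_bias = dy.sum(axis=0)                            # Bias@GRAD
    d_scale = (x_sub_mean * inv_std * dy).sum(axis=0)  # Scale@GRAD

    dy_scaled = dy * scale_                            # out_grad_scale
    dx_end = inv_std * dy_scaled
    d_mean = (-dx_end).sum(axis=1, keepdims=True) / shape_2
    d_std = ((-inv_var * x_sub_mean * dy_scaled).sum(axis=1, keepdims=True)
             * (inv_std / shape_2) * x_sub_mean)
    d_x = (dx_end + d_mean + d_std).reshape(orig_shape)
    return d_x, d_scale, d_bias
```

The `.astype(np.float32)` calls play the role of the FP16-to-FP32 casts in the C++ rule, and the three terms `dx_end`, `d_mean`, and `d_std` correspond one-to-one to the variables of the same names above.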

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 1 addition & 0 deletions
@@ -629,6 +629,7 @@
   kernel :
     func : layer_norm_grad
     data_type : out_grad
+  composite : layer_norm_grad(x, scale, bias, mean,varience, out_grad, epsilon, begin_norm_axis, x_grad, scale_grad, bias_grad)
   no_need_buffer : bias
   optional : scale, bias

paddle/phi/infermeta/ternary.cc

Lines changed: 10 additions & 1 deletion
@@ -574,14 +574,23 @@ void LayerNormInferMeta(const MetaTensor& x,
             right));
   }

+  phi::DataType x_dtype = x.dtype();
   out->set_dims(x_dim);
+  out->set_dtype(x_dtype);
+  out->share_lod(x);
+
+  phi::DataType param_type =
+      (x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
+          ? phi::DataType::FLOAT32
+          : x_dtype;
   if (mean) {
     mean->set_dims({left});
+    mean->set_dtype(param_type);
   }
   if (variance) {
     variance->set_dims({left});
+    variance->set_dtype(param_type);
   }
-  out->share_lod(x);
 }

 void LayerNormGradInferMeta(const MetaTensor& x,

python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py

Lines changed: 2 additions & 0 deletions
@@ -237,10 +237,12 @@ def test_train(self):

     def test_train_composite(self):
         core._set_prim_backward_enabled(True)
+        # core._add_skip_comp_ops("layer_norm")
         static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
         )
         core._set_prim_backward_enabled(False)
+        # core._add_skip_comp_ops("layer_norm")
         dygraph_loss, dygraph_ppl = self.train_dygraph(
             self.bert_config, self.data_reader
         )
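For context, the switch exercised by this test can be flipped around any dygraph-to-static program. A minimal sketch of how the composite layer_norm backward gets picked up (not from the patch: `TinyNet` and the shapes are illustrative, `core._set_prim_backward_enabled` is the internal flag used in test_bert.py, and exact behavior may differ by Paddle version):

```python
import paddle
from paddle.fluid import core


class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.norm = paddle.nn.LayerNorm(64)
        self.fc = paddle.nn.Linear(64, 10)

    def forward(self, x):
        return self.fc(self.norm(x))


core._set_prim_backward_enabled(True)    # decompose backward ops, incl. layer_norm_grad
net = paddle.jit.to_static(TinyNet())    # composite rules apply to the static graph
loss = net(paddle.randn([8, 64])).mean()
loss.backward()                          # runs the decomposed layer_norm gradient
core._set_prim_backward_enabled(False)   # restore the fused grad kernel
```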
