Commit c642aa1

add_triple_grad rules (PaddlePaddle#54164)
1 parent 94a56cc · commit c642aa1

5 files changed: +210 -117 lines changed

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@
 prim_white_list = [
     "matmul_double_grad",
     "subtract_double_grad",
+    "add_triple_grad",
     "silu_double_grad",
 ]

paddle/fluid/operators/elementwise/elementwise_add_op.cc

Lines changed: 38 additions & 0 deletions

@@ -154,6 +154,44 @@ class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class ElementwiseAddCompositeTripleGradOpMaker
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // get input
+    paddle::Tensor ddx = this->GetSingleForwardInput("DDX");
+    paddle::Tensor ddy = this->GetSingleForwardInput("DDY");
+    paddle::Tensor d_ddout = this->GetSingleOutputGrad("DDOut");
+
+    // get output
+    paddle::Tensor grad_grad_x_t =
+        this->GetSingleInputGrad(framework::GradVarName("DDX"));
+    paddle::Tensor grad_grad_y_t =
+        this->GetSingleInputGrad(framework::GradVarName("DDY"));
+    // get attr
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument("We only support axis = -1 in composite "
+                                     "add_triple_grad but we got: ",
+                                     axis));
+
+    paddle::Tensor* grad_grad_x = this->GetOutputPtr(&grad_grad_x_t);
+    std::string grad_grad_x_name = this->GetOutputName(grad_grad_x_t);
+    paddle::Tensor* grad_grad_y = this->GetOutputPtr(&grad_grad_y_t);
+    std::string grad_grad_y_name = this->GetOutputName(grad_grad_y_t);
+
+    VLOG(6) << "Runing add_triple_grad composite func";
+    prim::add_triple_grad<prim::DescTensor>(
+        ddx, ddy, d_ddout, axis, grad_grad_x, grad_grad_y);
+    this->RecoverOutputName(grad_grad_x_t, grad_grad_x_name);
+    this->RecoverOutputName(grad_grad_y_t, grad_grad_y_name);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
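As an aside (not text from the PR), the rule this maker delegates to prim::add_triple_grad can be summarized as follows; reduce_to_shape below is shorthand for the sum-and-reshape handling of broadcast dimensions performed by the composite implementation:

$$
\mathrm{out} = x + y \;\Rightarrow\; \mathrm{ddout} = \mathrm{ddx} + \mathrm{ddy}
$$
$$
\frac{\partial L}{\partial\,\mathrm{ddx}} = \mathrm{reduce\_to\_shape}\!\left(\frac{\partial L}{\partial\,\mathrm{ddout}},\ \mathrm{shape}(\mathrm{ddx})\right),\qquad
\frac{\partial L}{\partial\,\mathrm{ddy}} = \mathrm{reduce\_to\_shape}\!\left(\frac{\partial L}{\partial\,\mathrm{ddout}},\ \mathrm{shape}(\mathrm{ddy})\right)
$$

Because addition is linear, the second-order Jacobians are identities, so the triple-grad rule only has to undo broadcasting; that is exactly the by_pass versus sum-and-reshape branch structure in add_triple_grad below.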

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 0 additions & 117 deletions
(The subtract_double_grad, add_double_grad and multiply_double_grad helpers deleted here are moved into composite_double_backward_api.h in this same commit, not dropped.)

@@ -234,30 +234,6 @@ void subtract_grad(const Tensor& x,
   }
 }
 
-template <typename T>
-void subtract_double_grad(const Tensor& y,
-                          const Tensor& grad_out,
-                          const paddle::optional<Tensor>& grad_x_grad,
-                          const paddle::optional<Tensor>& grad_y_grad,
-                          int axis,
-                          Tensor* grad_out_grad) {
-  if (grad_out_grad) {
-    // ddout = ddx - ddy
-    if (!grad_x_grad && !grad_y_grad) {
-      grad_out_grad = nullptr;
-    } else {
-      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
-      if (grad_x_grad) {
-        ddout = ddout + grad_x_grad.get();
-      }
-      if (grad_y_grad) {
-        ddout = ddout - grad_y_grad.get();
-      }
-      set_output<T>(ddout, grad_out_grad);
-    }
-  }
-}
-
 template <typename T>
 void add_grad(const Tensor& x,
               const Tensor& y,
@@ -300,30 +276,6 @@ void add_grad(const Tensor& x,
   }
 }
 
-template <typename T>
-void add_double_grad(const Tensor& y,
-                     const Tensor& grad_out,
-                     const paddle::optional<Tensor>& grad_x_grad,
-                     const paddle::optional<Tensor>& grad_y_grad,
-                     int axis,
-                     Tensor* grad_out_grad) {
-  if (grad_out_grad) {
-    // ddout = ddx + ddy
-    if (!grad_x_grad && !grad_y_grad) {
-      grad_out_grad = nullptr;
-    } else {
-      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
-      if (grad_x_grad) {
-        ddout = ddout + grad_x_grad.get();
-      }
-      if (grad_y_grad) {
-        ddout = ddout + grad_y_grad.get();
-      }
-      set_output<T>(ddout, grad_out_grad);
-    }
-  }
-}
-
 template <typename T>
 void sum_grad(const Tensor& x,
               const Tensor& out_grad,
@@ -555,75 +507,6 @@ void multiply_grad(const Tensor& x,
   }
 }
 
-template <typename T>
-void multiply_double_grad(const Tensor& x,
-                          const Tensor& y,
-                          const Tensor& grad_out,
-                          const paddle::optional<Tensor>& grad_x_grad,
-                          const paddle::optional<Tensor>& grad_y_grad,
-                          int axis,
-                          Tensor* x_grad,
-                          Tensor* y_grad,
-                          Tensor* grad_out_grad) {
-  if (x_grad) {
-    if (grad_y_grad) {
-      auto dx = grad_y_grad.get() * grad_out;
-      if (dx.dims() != x.dims()) {
-        auto axes = get_reduce_dims_from_out(dx.dims(), x.dims());
-        if (!axes.size()) {
-          set_output<T>(dx, x_grad);
-        } else {
-          auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false);
-          if (dx_reduce.dims().size() != x.dims().size()) {
-            dx_reduce = reshape<T>(dx_reduce, x.shape());
-          }
-          set_output<T>(dx_reduce, x_grad);
-        }
-      } else {
-        set_output<T>(dx, x_grad);
-      }
-
-    } else {
-      x_grad = nullptr;
-    }
-  }
-  if (y_grad) {
-    if (grad_x_grad) {
-      auto dy = grad_x_grad.get() * grad_out;
-      if (dy.dims() != y.dims()) {
-        auto axes = get_reduce_dims_from_out(dy.dims(), y.dims());
-        if (!axes.size()) {
-          set_output<T>(dy, y_grad);
-        } else {
-          auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), false);
-          if (dy_reduce.dims().size() != y.dims().size()) {
-            dy_reduce = reshape<T>(dy_reduce, y.shape());
-          }
-          set_output<T>(dy_reduce, y_grad);
-        }
-      } else {
-        set_output<T>(dy, y_grad);
-      }
-    } else {
-      y_grad = nullptr;
-    }
-  }
-  if (grad_out_grad) {
-    if (grad_x_grad && grad_y_grad) {
-      auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
-      set_output<T>(ddout, grad_out_grad);
-    } else if (grad_x_grad) {
-      auto ddout = grad_x_grad.get() * y;
-      set_output<T>(ddout, grad_out_grad);
-    } else if (grad_y_grad) {
-      auto ddout = grad_y_grad.get() * x;
-      set_output<T>(ddout, grad_out_grad);
-    } else {
-      grad_out_grad = nullptr;
-    }
-  }
-}
-
 template <typename T>
 void expand_grad(const Tensor& x,
                  const Tensor& out_grad,

paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h

Lines changed: 170 additions & 0 deletions

@@ -383,5 +383,175 @@ void silu_double_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void multiply_double_grad(const Tensor& x,
+                          const Tensor& y,
+                          const Tensor& grad_out,
+                          const paddle::optional<Tensor>& grad_x_grad,
+                          const paddle::optional<Tensor>& grad_y_grad,
+                          int axis,
+                          Tensor* x_grad,
+                          Tensor* y_grad,
+                          Tensor* grad_out_grad) {
+  if (x_grad) {
+    if (grad_y_grad) {
+      auto dx = grad_y_grad.get() * grad_out;
+      if (dx.dims() != x.dims()) {
+        auto axes = get_reduce_dims_from_out(dx.dims(), x.dims());
+        if (!axes.size()) {
+          set_output<T>(dx, x_grad);
+        } else {
+          auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false);
+          if (dx_reduce.dims().size() != x.dims().size()) {
+            dx_reduce = reshape<T>(dx_reduce, x.shape());
+          }
+          set_output<T>(dx_reduce, x_grad);
+        }
+      } else {
+        set_output<T>(dx, x_grad);
+      }
+
+    } else {
+      x_grad = nullptr;
+    }
+  }
+  if (y_grad) {
+    if (grad_x_grad) {
+      auto dy = grad_x_grad.get() * grad_out;
+      if (dy.dims() != y.dims()) {
+        auto axes = get_reduce_dims_from_out(dy.dims(), y.dims());
+        if (!axes.size()) {
+          set_output<T>(dy, y_grad);
+        } else {
+          auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), false);
+          if (dy_reduce.dims().size() != y.dims().size()) {
+            dy_reduce = reshape<T>(dy_reduce, y.shape());
+          }
+          set_output<T>(dy_reduce, y_grad);
+        }
+      } else {
+        set_output<T>(dy, y_grad);
+      }
+    } else {
+      y_grad = nullptr;
+    }
+  }
+  if (grad_out_grad) {
+    if (grad_x_grad && grad_y_grad) {
+      auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
+      set_output<T>(ddout, grad_out_grad);
+    } else if (grad_x_grad) {
+      auto ddout = grad_x_grad.get() * y;
+      set_output<T>(ddout, grad_out_grad);
+    } else if (grad_y_grad) {
+      auto ddout = grad_y_grad.get() * x;
+      set_output<T>(ddout, grad_out_grad);
+    } else {
+      grad_out_grad = nullptr;
+    }
+  }
+}
+
+template <typename T>
+void add_double_grad(const Tensor& y,
+                     const Tensor& grad_out,
+                     const paddle::optional<Tensor>& grad_x_grad,
+                     const paddle::optional<Tensor>& grad_y_grad,
+                     int axis,
+                     Tensor* grad_out_grad) {
+  if (grad_out_grad) {
+    // ddout = ddx + ddy
+    if (!grad_x_grad && !grad_y_grad) {
+      grad_out_grad = nullptr;
+    } else {
+      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
+      if (grad_x_grad) {
+        ddout = ddout + grad_x_grad.get();
+      }
+      if (grad_y_grad) {
+        ddout = ddout + grad_y_grad.get();
+      }
+      set_output<T>(ddout, grad_out_grad);
+    }
+  }
+}
+
+template <typename T>
+void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
+                     const paddle::optional<Tensor>& grad_grad_y,
+                     const Tensor& grad_grad_out_grad,
+                     int axis,
+                     Tensor* grad_grad_x_grad,
+                     Tensor* grad_grad_y_grad) {
+  if (grad_grad_y_grad) {
+    if (grad_grad_y) {
+      if (grad_grad_y.get().dims() != grad_grad_out_grad.dims()) {
+        // Maybe need reduce here
+        phi::DDim reduce_dim = get_reduce_dims(grad_grad_y.get().dims(),
+                                               grad_grad_out_grad.dims());
+        if (!reduce_dim.size()) {
+          by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
+        } else {
+          auto dddy_reduce_res = grad_grad_out_grad.sum(
+              phi::vectorize(reduce_dim), grad_grad_y.get().dtype(), false);
+          auto dddy_tmp = reshape<T>(dddy_reduce_res,
+                                     phi::vectorize(grad_grad_y.get().dims()));
+          set_output<T>(dddy_tmp, grad_grad_y_grad);
+        }
+      } else {
+        by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
+      }
+    } else {
+      grad_grad_y_grad = nullptr;
+    }
+  }
+  if (grad_grad_x_grad) {
+    if (grad_grad_x) {
+      if (grad_grad_x.get().dims() != grad_grad_out_grad.dims()) {
+        // Maybe need reduce here
+        auto reduce_dim = get_reduce_dims(grad_grad_x.get().dims(),
+                                          grad_grad_out_grad.dims());
+        if (!reduce_dim.size()) {
+          by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
+        } else {
+          auto dddx_reduce_res = grad_grad_out_grad.sum(
+              phi::vectorize(reduce_dim), grad_grad_x.get().dtype(), false);
+          auto dddx_tmp = reshape<T>(dddx_reduce_res,
+                                     phi::vectorize(grad_grad_x.get().dims()));
+          set_output<T>(dddx_tmp, grad_grad_x_grad);
+        }
+      } else {
+        by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
+      }
+    } else {
+      grad_grad_x_grad = nullptr;
+    }
+  }
+}
+
+template <typename T>
+void subtract_double_grad(const Tensor& y,
+                          const Tensor& grad_out,
+                          const paddle::optional<Tensor>& grad_x_grad,
+                          const paddle::optional<Tensor>& grad_y_grad,
+                          int axis,
+                          Tensor* grad_out_grad) {
+  if (grad_out_grad) {
+    // ddout = ddx - ddy
+    if (!grad_x_grad && !grad_y_grad) {
+      grad_out_grad = nullptr;
+    } else {
+      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
+      if (grad_x_grad) {
+        ddout = ddout + grad_x_grad.get();
+      }
+      if (grad_y_grad) {
+        ddout = ddout - grad_y_grad.get();
+      }
+      set_output<T>(ddout, grad_out_grad);
+    }
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
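For quick reference, the relations encoded by the helpers moved into this header can be written as follows (a summary derived from the code above, not text from the PR; reduce_to_shape again abbreviates the sum-and-reshape handling of broadcast dimensions, and a missing ddx or ddy contributes nothing, i.e. is treated as zero):

$$
\text{multiply: } x'_{grad} = \mathrm{reduce\_to\_shape}(\mathrm{ddy}\odot \mathrm{dout},\ \mathrm{shape}(x)),\quad
y'_{grad} = \mathrm{reduce\_to\_shape}(\mathrm{ddx}\odot \mathrm{dout},\ \mathrm{shape}(y)),\quad
\mathrm{ddout} = \mathrm{ddx}\odot y + \mathrm{ddy}\odot x
$$
$$
\text{add: } \mathrm{ddout} = \mathrm{ddx} + \mathrm{ddy};\qquad
\text{subtract: } \mathrm{ddout} = \mathrm{ddx} - \mathrm{ddy}
$$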

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@
   kernel :
     func : add_triple_grad
   inplace : (grad_grad_out_grad -> grad_grad_x_grad)
+  composite : add_triple_grad (grad_grad_x, grad_grad_y, grad_grad_out_grad, axis, grad_grad_x_grad, grad_grad_y_grad)
 
 - backward_op : amax_grad
   forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out)
