
Add back_decomp and support dynamic shape for amax_grad and amin_grad #68818


Merged: 9 commits, Nov 29, 2024
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/codegen/decomp_vjp_gen.py
@@ -139,6 +139,8 @@
PRIM_VJP = UNARY_PRIM_VJP_OPS + BINARY_PRIM_VJP_OPS + OTHER_PRIM_VJP_OPS

CUSTOM_VJP = [
'amax_grad',
'amin_grad',
'bce_loss_grad',
'batch_norm_grad',
'dropout_grad',
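Listing `amax_grad` and `amin_grad` under `CUSTOM_VJP` directs the decomp VJP code generator to the hand-written templates added in `details.h` below, rather than to an auto-generated composition of primitive VJPs.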
118 changes: 118 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -3329,6 +3329,124 @@ void ceil_grad(const Tensor& out_grad, Tensor* x_grad) {
}
}

template <typename T>
void amax_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (x_grad) {
Tensor x_grad_tmp;
if (has_dynamic_shape(x.shape())) {
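      // Dynamic-shape path: shapes are only known at runtime, so they are built
      // as tensors (shape64) and consumed by backend:: ops.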
const Tensor x_shape = shape64<T>(x);
const Tensor zero_tensor =
backend::full_with_tensor<T>(x_shape, 0.0, x.dtype());
const int64_t axis_size = axis.size();
const int64_t x_dim_size = x.dims().size();

reduce_all = false;
if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
reduce_all = true;
}

if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
auto out_grad_tmp = backend::expand<T>(out_grad, x_shape);
auto out_tmp = backend::expand<T>(out, x_shape);
auto mask = equal<T>(x, out_tmp);
auto mask_sum = backend::sum<T>(mask, axis, x.dtype(), keepdim = true);
auto grad_tmp = out_grad_tmp / mask_sum;
x_grad_tmp = where<T>(mask, grad_tmp, zero_tensor);
} else {
const Tensor out_grad_shape = shape64<T>(out_grad);
auto axis_ = std::vector<int64_t>();

if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
const Tensor out_grad_shape_extend =
get_unsqueeze_dims<T>(out_grad_shape, axis_);
auto out_grad_ = backend::reshape<T>(out_grad, out_grad_shape_extend);
auto out_ = backend::reshape<T>(out, out_grad_shape_extend);
auto out_grad_tmp = backend::expand<T>(out_grad_, x_shape);
auto out_tmp = backend::expand<T>(out_, x_shape);
auto mask = equal<T>(x, out_tmp);
auto mask_sum = backend::sum<T>(mask, axis_, x.dtype(), keepdim = true);
auto grad_tmp = out_grad_tmp / mask_sum;
x_grad_tmp = where<T>(mask, grad_tmp, zero_tensor);
}
} else {
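      // Static-shape path: dimensions are known, so plain int64 vectors and
      // ordinary expand/reshape suffice.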
auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
reduce_all = false;
if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
reduce_all = true;
}

if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
auto out_tmp = out.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
auto mask_sum = sum<T>(mask, axis, x.dtype(), keepdim = true);
auto grad_tmp = out_grad_tmp / mask_sum;
x_grad_tmp = where<T>(mask, grad_tmp, zero_tensor);
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
auto out_ = reshape<T>(out, out_grad_shape);
auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
auto out_tmp = out_.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
auto mask_sum = sum<T>(mask, axis_, x.dtype(), keepdim = true);
auto grad_tmp = out_grad_tmp / mask_sum;
x_grad_tmp = where<T>(mask, grad_tmp, zero_tensor);
}
}
set_output<T>(x_grad_tmp, x_grad);
}
}

template <typename T>
void amin_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (x_grad) {
Tensor x_grad_tmp;
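    // The equality mask x == out selects the extrema either way, so the amax
    // decomposition applies to amin unchanged.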
amax_grad<T>(x, out, out_grad, axis, keepdim, reduce_all, &x_grad_tmp);

set_output<T>(x_grad_tmp, x_grad);
}
}

} // namespace details
} // namespace primitive
} // namespace paddle
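For reference, the templates above implement the usual tie-splitting rule for `amax`/`amin` gradients: the upstream gradient flows only to the elements equal to the reduced extremum and is divided by the number of tied elements along the reduced axes; `amin_grad` can delegate to `amax_grad` because the mask `x == out` picks out the minima when `out` is the reduced minimum. A minimal NumPy sketch of that rule (illustrative only, not part of this diff; a single integer `axis` is assumed):

```python
import numpy as np

def amax_grad_reference(x, out_grad, axis, keepdim):
    """Tie-splitting gradient of amax along a single integer axis (reference sketch)."""
    out = np.amax(x, axis=axis, keepdims=True)
    # Broadcast the upstream gradient back to x's shape.
    g = out_grad if keepdim else np.expand_dims(out_grad, axis)
    mask = x == out                              # elements that attained the maximum
    count = mask.sum(axis=axis, keepdims=True)   # number of tied maxima per slice
    return np.where(mask, g / count, 0.0)        # split the gradient equally among ties

x = np.array([[1.0, 3.0, 3.0],
              [2.0, 2.0, 1.0]])
print(amax_grad_reference(x, np.ones(2), axis=1, keepdim=False))
# [[0.  0.5 0.5]
#  [0.5 0.5 0. ]]
```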
2 changes: 2 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -31,6 +31,8 @@
ALLOW_DYNAMIC_SHAPE_VJP_OPS = [
"pd_op.abs",
"pd_op.add",
"pd_op.amax",
"pd_op.amin",
"pd_op.argsort",
"pd_op.assign",
"pd_op.batch_norm_",
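Adding `pd_op.amax` and `pd_op.amin` to `ALLOW_DYNAMIC_SHAPE_VJP_OPS` permits their backward decomposition when the program carries unknown dimensions. A hedged sketch of how such a dynamic dimension typically arises (the prim switch shown is an assumption about the test setup, not part of this diff):

```python
import paddle
from paddle.base import core
from paddle.static import InputSpec

# Assumed switch: decompose backward (VJP) ops into primitives; the real test
# harness may set additional PIR-related flags.
core._set_prim_backward_enabled(True)

def reduce_max(x):
    return paddle.amax(x, axis=-1)

# A `None` dimension in the InputSpec leaves the leading dimension unknown,
# which is the dynamic-shape case the whitelist entry above now permits.
static_fn = paddle.jit.to_static(
    reduce_max,
    input_spec=[InputSpec(shape=[None, 8], dtype='float32')],
    full_graph=True,
)

x = paddle.rand([4, 8])
x.stop_gradient = False
static_fn(x).sum().backward()
print(x.grad.shape)  # [4, 8]
```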
3 changes: 2 additions & 1 deletion test/prim/pir_prim/CMakeLists.txt
@@ -22,7 +22,8 @@ set(TEST_PRIM_PURE_PIR_CASES
test_decomp_whole_program
test_dynamic_combine1
test_dynamic_combine2
test_decomp_fallback)
test_decomp_fallback
test_prim_amax_amin_op)

foreach(target ${TEST_PRIM_PURE_PIR_CASES})
py_test_modules(
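The new `test_prim_amax_amin_op` target is registered here; its contents are not shown in this diff, but a check in the same spirit could compare Paddle's gradient against the tie-splitting reference above (the helper name below is hypothetical):

```python
import numpy as np
import paddle

def eager_amax_grad(x_np, axis):
    # Gradient of amax via eager autograd; tied maxima share the gradient equally.
    x = paddle.to_tensor(x_np, stop_gradient=False)
    paddle.amax(x, axis=axis).sum().backward()
    return x.grad.numpy()

x_np = np.array([[1.0, 3.0, 3.0],
                 [2.0, 2.0, 1.0]], dtype='float32')
got = eager_amax_grad(x_np, axis=1)
want = np.array([[0.0, 0.5, 0.5],
                 [0.5, 0.5, 0.0]], dtype='float32')
np.testing.assert_allclose(got, want, rtol=1e-6)
```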