[Prim]Support dynamic shape for sigmoid_grad decomp (#64750)
* [Prim]Support dynamic shape for sigmoid_grad decomp

* fix conflict

* fix typo
Aurelius84 authored May 31, 2024
1 parent 6e0b844 commit f693970
Showing 3 changed files with 54 additions and 50 deletions.
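Most of the diff is one mechanical substitution: every full<T>(empty_shape, value, dtype) call in composite.h becomes full_scalar<T>(value, dtype), the helper itself is added in manual_primitive.h, and sigmoid_grad in details.h builds its constant 1 the same way so the decomposition also works when out has a dynamic shape (the point of the commit title). A minimal sketch of the pattern, assuming Paddle's primitive headers are available; scaled_offset is a hypothetical illustration, not part of the patch:

// Before the patch, each call site spelled out the 0-D shape:
//   auto s = full<T>(empty_shape, slope, x.dtype());
// After the patch, full_scalar<T> hides the empty shape and returns a
// 0-D tensor, which broadcasts against x even when x's shape is dynamic.
template <typename T>
Tensor scaled_offset(const Tensor& x, float slope, float offset) {
  auto s = full_scalar<T>(slope, x.dtype());   // 0-D scalar tensor
  auto b = full_scalar<T>(offset, x.dtype());  // 0-D scalar tensor
  return x * s + b;                            // broadcasts over any shape of x
}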
92 changes: 43 additions & 49 deletions paddle/fluid/primitive/composite/composite.h
@@ -23,9 +23,6 @@ namespace paddle {
namespace primitive {
namespace details {

// empty_shape means x.shape=[]
static std::vector<int64_t> empty_shape;

template <typename T>
static Tensor get_slice(const Tensor& x, int64_t idx) {
return slice<T>(x, {0}, {idx}, {idx + 1}, {1}, {});
@@ -98,7 +95,7 @@ Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
for (size_t i = 0; i < axis_.size(); i++) {
value_ *= x_dim[axis_[i]];
}
value = full<T>(empty_shape, value_, sum_x.dtype());
value = full_scalar<T>(value_, sum_x.dtype());
}

Tensor res = sum_x / value;
@@ -148,7 +145,7 @@ Tensor p_norm_decomp(const Tensor& x,
Tensor res;
if (porder == 0.0) {
// 0-norm
auto zero = full<T>(empty_shape, 0, x_tmp.dtype());
auto zero = full_scalar<T>(0, x_tmp.dtype());
auto none_zero = not_equal<T>(x_tmp, zero);
res = cast<T>(none_zero, x_tmp.dtype());
res = sum<T>(res, {axis}, x_tmp.dtype(), keepdim);
@@ -169,8 +166,8 @@ Tensor p_norm_decomp(const Tensor& x,
res = min<T>(x_tmp, {axis}, keepdim);
} else {
// vanilla p-norm
auto porder_tensor = full<T>(empty_shape, porder, x_tmp.dtype());
auto inv_porder_tensor = full<T>(empty_shape, 1 / porder, x_tmp.dtype());
auto porder_tensor = full_scalar<T>(porder, x_tmp.dtype());
auto inv_porder_tensor = full_scalar<T>(1 / porder, x_tmp.dtype());
res = elementwise_pow<T>(x_tmp, porder_tensor);
res = sum<T>(res, {axis}, x_tmp.dtype(), keepdim);
res = elementwise_pow<T>(res, inv_porder_tensor);
@@ -194,8 +191,7 @@ Tensor pow_decomp(const Tensor& x, const paddle::Scalar& y) {
}

check_valid_type(y.dtype());
Tensor y_full = full<T>(empty_shape, y, x_cast.dtype());

Tensor y_full = full_scalar<T>(y, x_cast.dtype());
auto ans = elementwise_pow<T>(x_cast, y_full);
if (need_cast) {
return cast<T>(ans, org_dtype);
@@ -282,13 +278,13 @@ Tensor squared_l2_norm_decomp(const Tensor& x) {

template <typename T>
Tensor reciprocal_decomp(const Tensor& x) {
return full<T>(empty_shape, 1.0, x.dtype()) / x;
return full_scalar<T>(1.0, x.dtype()) / x;
}

template <typename T>
Tensor bce_loss_decomp(const Tensor& x, const Tensor& label) {
auto one = full<T>(empty_shape, 1, x.dtype());
auto ans = full<T>(empty_shape, -1, x.dtype()) *
auto one = full_scalar<T>(1, x.dtype());
auto ans = full_scalar<T>(-1, x.dtype()) *
(label * log<T>(x) + (one - label) * log<T>(one - x));
return ans;
}
@@ -382,7 +378,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_decomp(
}
}

Tensor half = full<T>(empty_shape, -0.5, x_cast.dtype());
Tensor half = full_scalar<T>(-0.5, x_cast.dtype());

bool use_run_stat = (is_test && (!trainable_statistics)) || use_global_stats;
Tensor x_hat;
@@ -421,9 +417,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_decomp(
run_var_ = assign<T>(run_var);
}
Tensor y;
Tensor new_scale =
scale ? scale.get() : full<T>(empty_shape, 1, x_cast.dtype());
Tensor new_bias = bias ? bias.get() : full<T>(empty_shape, 0, x_cast.dtype());
Tensor new_scale = scale ? scale.get() : full_scalar<T>(1, x_cast.dtype());
Tensor new_bias = bias ? bias.get() : full_scalar<T>(0, x_cast.dtype());
if (data_layout_ == DataLayout::kNHWC) {
y = x_hat * new_scale + new_bias;
} else {
@@ -539,13 +534,13 @@ Tensor swiglu_decomp(const Tensor& x, const paddle::optional<Tensor>& y) {

template <typename T>
Tensor relu_decomp(const Tensor& x) {
return maximum<T>(x, full<T>(empty_shape, 0.0, x.dtype()));
return maximum<T>(x, full_scalar<T>(0.0, x.dtype()));
}

template <typename T>
Tensor relu6_decomp(const Tensor& x) {
auto tmp = maximum<T>(x, full<T>(empty_shape, 0.0, x.dtype()));
auto res = minimum<T>(tmp, full<T>(empty_shape, 6.0, x.dtype()));
auto tmp = maximum<T>(x, full_scalar<T>(0.0, x.dtype()));
auto res = minimum<T>(tmp, full_scalar<T>(6.0, x.dtype()));
return res;
}

@@ -653,7 +648,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
auto difference = x_cast - mean_;
auto var_tmp1 = difference * difference;
auto variance = mean_decomp<T>(var_tmp1, axis, true);
auto var_tmp3 = variance + full<T>(empty_shape, epsilon, variance.dtype());
auto var_tmp3 = variance + full_scalar<T>(epsilon, variance.dtype());
auto rsqrt_var = rsqrt<T>(var_tmp3);
auto out = difference * rsqrt_var;

@@ -798,18 +793,18 @@ std::tuple<Tensor, Tensor> dropout_decomp(
Tensor uniform_tensor;
if (has_dynamic_shape(x.shape())) {
auto shape_tensor = shape<T>(x);
auto zero = full<T>(empty_shape, 0.0, dtype_tmp);
auto one = full<T>(empty_shape, 1.0, dtype_tmp);
auto zero = full_scalar<T>(0.0, dtype_tmp);
auto one = full_scalar<T>(1.0, dtype_tmp);
uniform_tensor =
backend::uniform<T>(shape_tensor, zero, one, dtype_tmp, seed_tmp);
} else {
uniform_tensor =
uniform<T>(phi::vectorize(x.dims()), dtype_tmp, 0.0, 1.0, seed_tmp);
}
auto mask = cast<T>(
greater_equal<T>(uniform_tensor, full<T>(empty_shape, p, dtype_tmp)),
org_dtype);
auto ones_p = full<T>(empty_shape, 1.0 - p.to<float>(), org_dtype);
auto mask =
cast<T>(greater_equal<T>(uniform_tensor, full_scalar<T>(p, dtype_tmp)),
org_dtype);
auto ones_p = full_scalar<T>(1.0 - p.to<float>(), org_dtype);
if (upscale_in_train) {
if (is_test) {
// inference: out = input
@@ -818,7 +813,7 @@ std::tuple<Tensor, Tensor> dropout_decomp(
// train: out = input * mask / ( 1.0 - p )
if (p.to<float>() == 1.0) {
// Process p=1. for avoid divide zero error (x*mask/(1.0-p))
auto zero = full<T>(empty_shape, 0.0, org_dtype);
auto zero = full_scalar<T>(0.0, org_dtype);
return std::make_tuple(x * zero, cast<T>(zero, DataType::UINT8));
} else {
auto ans = (x * mask) / ones_p;
@@ -842,20 +837,20 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) {
const double PM_SQRT1_2 = 0.70710678118654752440; /* 1/sqrt(2) */

auto org_dtype = x.dtype();
auto half = full<T>(empty_shape, 0.5, org_dtype);
auto one = full<T>(empty_shape, 1.0, org_dtype);
auto half = full_scalar<T>(0.5, org_dtype);
auto one = full_scalar<T>(1.0, org_dtype);
if (approximate) {
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
auto kAlpha = full<T>(empty_shape, PM_2_SQRTPI * PM_SQRT1_2, org_dtype);
auto GELU_CONSTANT = full<T>(empty_shape, 0.044715, org_dtype);
auto x_pow3 = elementwise_pow<T>(x, full<T>(empty_shape, 3, org_dtype));
auto kAlpha = full_scalar<T>(PM_2_SQRTPI * PM_SQRT1_2, org_dtype);
auto GELU_CONSTANT = full_scalar<T>(0.044715, org_dtype);
auto x_pow3 = elementwise_pow<T>(x, full_scalar<T>(3, org_dtype));
auto tanh_out = tanh<T>(kAlpha * (x + x_pow3 * GELU_CONSTANT));

auto res = x * half * (one + tanh_out);
return res;
} else {
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
auto M_SQRT1_2T = full<T>(empty_shape, PM_SQRT1_2, org_dtype);
auto M_SQRT1_2T = full_scalar<T>(PM_SQRT1_2, org_dtype);
auto erf_out = one + erf<T>(x * M_SQRT1_2T);

auto res = x * half * erf_out;
@@ -867,10 +862,10 @@ template <typename T>
Tensor hardsigmoid_decomp(const Tensor& x, float slope, float offset) {
const double MAX_VALUE = 1.0;
const double MIN_VALUE = 0.0;
return maximum<T>(minimum<T>(x * full<T>(empty_shape, slope, x.dtype()) +
full<T>(empty_shape, offset, x.dtype()),
full<T>(empty_shape, MAX_VALUE, x.dtype())),
full<T>(empty_shape, MIN_VALUE, x.dtype()));
return maximum<T>(minimum<T>(x * full_scalar<T>(slope, x.dtype()) +
full_scalar<T>(offset, x.dtype()),
full_scalar<T>(MAX_VALUE, x.dtype())),
full_scalar<T>(MIN_VALUE, x.dtype()));
}

template <typename T>
@@ -881,15 +876,15 @@ Tensor hardswish_decomp(const Tensor& x) {

// out = minimum(maximum(x + offset, 0), threshold) * x / scale
auto minimum_out =
minimum<T>(maximum<T>(x + full<T>(empty_shape, OFFSET, x.dtype()),
full<T>(empty_shape, 0.0, x.dtype())),
full<T>(empty_shape, THRESHOLD, x.dtype()));
return (minimum_out * x) / full<T>(empty_shape, SCALE, x.dtype());
minimum<T>(maximum<T>(x + full_scalar<T>(OFFSET, x.dtype()),
full_scalar<T>(0.0, x.dtype())),
full_scalar<T>(THRESHOLD, x.dtype()));
return (minimum_out * x) / full_scalar<T>(SCALE, x.dtype());
}

template <typename T>
Tensor leaky_relu_decomp(const Tensor& x, float negative_slope) {
auto multiply_tmp = full<T>(empty_shape, negative_slope, x.dtype()) * x;
auto multiply_tmp = full_scalar<T>(negative_slope, x.dtype()) * x;
if (negative_slope < 1.0) {
return maximum<T>(x, multiply_tmp);
} else {
@@ -1127,8 +1122,7 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
var_ = maximum<T>(
var_tmp_,
backend::full_with_tensor<T>(shape<T>(var_tmp_), 0, var_tmp_.dtype()));
Tensor var_inv =
rsqrt<T>(var_ + full<T>(empty_shape, epsilon, var_.dtype()));
Tensor var_inv = rsqrt<T>(var_ + full_scalar<T>(epsilon, var_.dtype()));
Tensor res = (x_cast - mean_) * var_inv;
out = backend::reshape<T>(res, x_dim_t);
} else {
@@ -1143,7 +1137,7 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
auto var_tmp_ =
mean_decomp<T>(x_cast * x_cast, c_axis, true) - mean_ * mean_;
var_ = maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
auto var_inv = rsqrt<T>(var_ + full<T>(empty_shape, epsilon, var_.dtype()));
auto var_inv = rsqrt<T>(var_ + full_scalar<T>(epsilon, var_.dtype()));
auto res = (x_cast - mean_) * var_inv;
out = reshape<T>(res, x_dim);
}
@@ -1207,7 +1201,7 @@ Tensor square_decomp(const Tensor& x) {
}

Tensor two;
two = full<T>(empty_shape, 2, x_cast.dtype());
two = full_scalar<T>(2, x_cast.dtype());

auto ans = elementwise_pow<T>(x_cast, two);
if (need_cast) {
@@ -1247,7 +1241,7 @@ Tensor sigmoid_cross_entropy_with_logits_decomp(
const Tensor tmp_norm = sum<T>(where<T>(abs<T>(diff) > eps1, one, zero));
// Follow the implementation in
// paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_kernel.cc
const Tensor eps2 = full<T>(empty_shape, 1e-5, x.type());
const Tensor eps2 = full_scalar<T>(1e-5, x.type());
auto norm = where<T>(tmp_norm > eps2, tmp_norm, eps2);
out = out / norm;
}
@@ -1387,8 +1381,8 @@ Tensor elu_decomp(const Tensor& x, const float alpha) {

if (has_dynamic_shape(x_cast.shape())) {
zero = backend::full_with_tensor<T>(shape<T>(x_cast), 0, x_cast.dtype());
tmp_res = full<T>(empty_shape, alpha, x_cast.dtype()) *
(exp<T>(x_cast) - full<T>(empty_shape, 1, x_cast.dtype()));
tmp_res = full_scalar<T>(alpha, x_cast.dtype()) *
(exp<T>(x_cast) - full_scalar<T>(1, x_cast.dtype()));
} else {
zero = full<T>(x_cast.shape(), 0, x_cast.type());
tmp_res = alpha * (exp<T>(x_cast) - 1);
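The elu_decomp hunk above shows the broader contract behind these replacements: constants with a full tensor shape can only be built when the shape is known statically, while scalar (0-D) constants work in both the dynamic-shape and static-shape branches. A condensed sketch of that branching, assuming the primitive helpers that already appear in this diff (has_dynamic_shape, backend::full_with_tensor, shape, where); elu_like is an illustrative name, not part of the patch:

template <typename T>
Tensor elu_like(const Tensor& x, float alpha) {
  Tensor zero, tmp;
  if (has_dynamic_shape(x.shape())) {
    // Shape only known at runtime: build the zero tensor from the runtime
    // shape, and build the constants as 0-D tensors via full_scalar.
    zero = backend::full_with_tensor<T>(shape<T>(x), 0, x.dtype());
    tmp = full_scalar<T>(alpha, x.dtype()) *
          (exp<T>(x) - full_scalar<T>(1, x.dtype()));
  } else {
    // Static shape: a plain full with the compile-time shape suffices.
    zero = full<T>(x.shape(), 0, x.dtype());
    tmp = alpha * (exp<T>(x) - 1);
  }
  return where<T>(x > zero, x, tmp);
}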
9 changes: 9 additions & 0 deletions paddle/fluid/primitive/manual/manual_primitive.h
@@ -30,6 +30,15 @@ Tensor full(const IntArray& shape,
return backend::full<T>(shape, value, dtype, place);
}

template <typename T>
Tensor full_scalar(const Scalar& value,
DataType dtype = DataType::FLOAT32,
Place place = Place()) {
// empty_shape means x.shape=[]
std::vector<int64_t> empty_shape;
return backend::full<T>(empty_shape, value, dtype, place);
}

template <typename T>
Tensor assign_out_(const Tensor& x, const Tensor& output) {
return backend::assign_out_<T>(x, output);
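The new helper is just full with an empty shape vector, so call sites no longer need to share a file-level static empty_shape and always get a 0-D (shape []) tensor back. A usage sketch mirroring reciprocal_decomp from composite.h above; reciprocal_like is an illustrative name, not part of the patch:

template <typename T>
Tensor reciprocal_like(const Tensor& x) {
  // full_scalar<T>(1.0, x.dtype()) is effectively
  // backend::full<T>(std::vector<int64_t>{}, 1.0, x.dtype()).
  return full_scalar<T>(1.0, x.dtype()) / x;
}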
3 changes: 2 additions & 1 deletion paddle/fluid/primitive/rule/vjp/details.h
@@ -1548,7 +1548,8 @@ void leaky_relu_grad(const Tensor& out,
template <typename T>
void sigmoid_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
set_output<T>(out_grad * (out * (1 - out)), x_grad);
auto one_tensor = full_scalar<T>(1.0, out.dtype());
set_output<T>(out_grad * (out * (one_tensor - out)), x_grad);
}
}

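The details.h change follows directly from the sigmoid derivative; only the way the constant 1 is materialized changes. Building it with full_scalar yields a 0-D tensor that broadcasts against out regardless of its shape, whereas the previous out_grad * (out * (1 - out)) form is what the commit title flags as not supporting dynamic shapes. The identity being implemented, for reference:

\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\frac{d\sigma}{dx} = \sigma(x)\bigl(1 - \sigma(x)\bigr) = \text{out}\,(1 - \text{out}),
\qquad
\text{x\_grad} = \text{out\_grad} \cdot \text{out}\,(1 - \text{out}).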
