fix lamb beta1pow beta2pow update (PaddlePaddle#38518)
sneaxiy authored Dec 29, 2021
1 parent 72a41e5 commit 3672480
Showing 1 changed file with 108 additions and 72 deletions.
180 changes: 108 additions & 72 deletions paddle/fluid/operators/optimizers/lamb_op.h
@@ -52,19 +52,16 @@ struct LambMomentREGUpdateFunctor {
const bool* skip_update_;

LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
const MT* mom2, MT* mom2_out, const T* grad,
const MT* param, MT* trust_ratio_div,
const bool* skip_update)
MT beta1_pow, MT beta2_pow, const MT* mom1,
MT* mom1_out, const MT* mom2, MT* mom2_out,
const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -95,10 +92,6 @@ struct LambMomentREGUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};

@@ -113,9 +106,7 @@ struct LambMomentMENUpdateFunctor {
MT epsilon_;

const MT* beta1_pow_;
MT* beta1_pow_out_;
const MT* beta2_pow_;
MT* beta2_pow_out_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
@@ -126,8 +117,7 @@ struct LambMomentMENUpdateFunctor {
const bool* skip_update_;

LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
const MT* beta1_pow, MT* beta1_pow_out,
const MT* beta2_pow, MT* beta2_pow_out,
const MT* beta1_pow, const MT* beta2_pow,
const MT* mom1, MT* mom1_out, const MT* mom2,
MT* mom2_out, const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
@@ -136,9 +126,7 @@ struct LambMomentMENUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -168,10 +156,6 @@ struct LambMomentMENUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};

@@ -183,9 +167,7 @@ struct SparseLambMomentREGUpdateFunctor {
T epsilon_;

T beta1_pow_;
T* beta1_pow_out_;
T beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
@@ -201,20 +183,18 @@ struct SparseLambMomentREGUpdateFunctor {
const bool* skip_update_;

SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div,
const int64_t* rows, int64_t row_numel,
int64_t row_count, const bool* skip_update)
T beta1_pow, T beta2_pow, const T* mom1,
T* mom1_out, const T* mom2, T* mom2_out,
const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
int64_t row_numel, int64_t row_count,
const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -246,10 +226,6 @@ struct SparseLambMomentREGUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}

inline HOSTDEVICE void operator()(size_t i) const {
@@ -270,9 +246,7 @@ struct SparseLambMomentMENUpdateFunctor {
T epsilon_;

const T* beta1_pow_;
T* beta1_pow_out_;
const T* beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
@@ -288,8 +262,7 @@ struct SparseLambMomentMENUpdateFunctor {
const bool* skip_update_;

SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* beta1_pow, const T* beta2_pow,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
@@ -300,9 +273,7 @@ struct SparseLambMomentMENUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -334,10 +305,6 @@ struct SparseLambMomentMENUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}

inline HOSTDEVICE void operator()(size_t i) const {
@@ -350,11 +317,44 @@ struct SparseLambMomentMENUpdateFunctor {
}
};

template <typename T, bool IsMultiPrecision>
struct LambParamUpateFunctor {
using MT = typename std::conditional<
IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
template <typename MT, bool NeedUpdateBetaPow /*=true*/>
struct LambBetaPowUpdateFunctor {
void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
MT* beta2pow_out, MT beta1, MT beta2) {
beta1pow_ = beta1pow;
beta2pow_ = beta2pow;
beta1pow_out_ = beta1pow_out;
beta2pow_out_ = beta2pow_out;
beta1_ = beta1;
beta2_ = beta2;
}

HOSTDEVICE void UpdateBetaPow(size_t i) const {
if (i == 0) {
beta1pow_out_[0] = beta1pow_[0] * beta1_;
beta2pow_out_[0] = beta2pow_[0] * beta2_;
}
}

private:
const MT* beta1pow_;
const MT* beta2pow_;
MT* beta1pow_out_;
MT* beta2pow_out_;
MT beta1_;
MT beta2_;
};

template <typename MT>
struct LambBetaPowUpdateFunctor<MT, /*NeedUpdateBetaPow=*/false> {
void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
MT* beta2pow_out, MT beta1, MT beta2) {}
HOSTDEVICE void UpdateBetaPow(size_t) const {}
};

template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
struct LambParamUpateFunctor
: public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
const MT* lr_;
const T* param_;
const MT* master_param_;
@@ -396,6 +396,7 @@ struct LambParamUpateFunctor {
if (IsMultiPrecision) {
master_param_out_[i] = param_out;
}
this->UpdateBetaPow(i);
}
};

@@ -501,15 +502,19 @@ class LambOpKernel : public framework::OpKernel<T> {
: nullptr;

// Update moments
bool should_update_beta_pow_later = false;
const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
<< " , Beta2Pow place: " << beta2_pow.place();
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = grad_var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<MT>(),
nullptr, *beta2_pow.template data<MT>(), nullptr,
mom1.template data<MT>(),
*beta2_pow.template data<MT>(), mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -523,12 +528,17 @@ class LambOpKernel : public framework::OpKernel<T> {
beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<MT>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<MT>(),
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
beta2_pow.template data<MT>(),
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
mom1.template data<MT>(),
weight_decay, beta1, beta2, epsilon,
static_cast<const MT*>(beta1_pow_ptr),
static_cast<const MT*>(beta2_pow_ptr), mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -542,7 +552,12 @@ class LambOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True"));
"multi_precision=True."));
constexpr bool kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(kIsSameType, true,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True."));
auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
"Input", "Grad", "Lamb");
if (grad.rows().size() == 0) {
@@ -582,8 +597,8 @@ class LambOpKernel : public framework::OpKernel<T> {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
*beta1_pow.template data<T>(), nullptr,
*beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
*beta1_pow.template data<T>(), *beta2_pow.template data<T>(),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -595,14 +610,18 @@ class LambOpKernel : public framework::OpKernel<T> {
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
mom1.template data<T>(),
reinterpret_cast<const T*>(beta1_pow_ptr),
reinterpret_cast<const T*>(beta2_pow_ptr), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -639,14 +658,31 @@ class LambOpKernel : public framework::OpKernel<T> {
}
trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();

LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
lr.template data<MT>(), static_cast<const T*>(param_ptr),
static_cast<const MT*>(master_param_ptr), p_norm_t.template data<MT>(),
trust_ratio_div.template data<MT>(),
trust_ratio_div_norm_t.template data<MT>(),
static_cast<T*>(param_out_ptr), static_cast<MT*>(master_param_out_ptr),
skip_update_flag);
for_range(param_update_functor);
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
param_update_functor( \
lr.template data<MT>(), static_cast<const T*>(param_ptr), \
static_cast<const MT*>(master_param_ptr), \
p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(), \
trust_ratio_div_norm_t.template data<MT>(), \
static_cast<T*>(param_out_ptr), \
static_cast<MT*>(master_param_out_ptr), skip_update_flag); \
if (__should_update_beta_pow) { \
param_update_functor.SetBetaPows(beta1_pow_ptr, beta2_pow_ptr, \
beta1_pow_out_ptr, beta2_pow_out_ptr, \
beta1, beta2); \
} \
for_range(param_update_functor); \
} while (0)

if (should_update_beta_pow_later) {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
} else {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
}

#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}
};
