Merged
1 change: 1 addition & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -577,6 +577,7 @@ XPUOpMap& get_kl2_ops() {
{"mean_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"mean", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"merged_adam", XPUKernelSet({phi::DataType::FLOAT32})},
{"merged_momentum",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"mish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
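For context on how this registration is exercised from Python: Paddle's Adam optimizer exposes a `use_multi_tensor` option that batches the updates of many parameters into one call, which is the path a merged_adam kernel serves. The sketch below is illustrative only; it assumes an XPU build of Paddle, and the snippet itself does not guarantee dispatch to this particular kernel.

```python
# Sketch: driving the merged (multi-tensor) Adam path on an XPU device.
# Assumes a Paddle build with XPU support; `use_multi_tensor=True` groups
# FP32 parameter updates, which is what the merged_adam kernel handles.
import paddle

paddle.set_device("xpu")

model = paddle.nn.Sequential(
    paddle.nn.Linear(8, 16),
    paddle.nn.ReLU(),
    paddle.nn.Linear(16, 4),
)
opt = paddle.optimizer.Adam(
    learning_rate=1e-3,
    parameters=model.parameters(),
    use_multi_tensor=True,  # batch all FP32 params into one merged update
)

loss = model(paddle.randn([2, 8])).mean()
loss.backward()
opt.step()        # expected to hit merged_adam for the FP32 parameter group
opt.clear_grad()
```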
231 changes: 231 additions & 0 deletions paddle/phi/kernels/xpu/adam_kernel.cc
@@ -240,6 +240,229 @@ void AdamDenseKernel(const Context& dev_ctx,
funcs::FreeData<float>(moment2, mom2_ptr);
funcs::FreeData<float>(learning_rate, lr_ptr);
}

template <typename T, typename Context>
void MergedAdamKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& learning_rate,
const std::vector<const DenseTensor*>& moment1,
const std::vector<const DenseTensor*>& moment2,
const std::vector<const DenseTensor*>& beta1_pow,
const std::vector<const DenseTensor*>& beta2_pow,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> moment1_out,
std::vector<DenseTensor*> moment2_out,
std::vector<DenseTensor*> beta1_pow_out,
std::vector<DenseTensor*> beta2_pow_out,
std::vector<DenseTensor*> master_param_out) {
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

auto beta1_ = beta1.to<float>();
auto beta2_ = beta2.to<float>();
auto epsilon_ = epsilon.to<float>();
int64_t step_ = 0;
int64_t mode_ = 2;
int64_t bias_correction_ = 1;
float weight_decay_ = 0.0;

DenseTensor lr_host;
lr_host.Resize(learning_rate[0]->dims());
dev_ctx.template HostAlloc<float>(&lr_host);
phi::Copy(dev_ctx, *learning_rate[0], CPUPlace(), false, &lr_host);
float lr_ = *(lr_host.template data<float>());

float beta1_pow_data;
if (beta1_pow[0]->place() == CPUPlace()) {
beta1_pow_data = *(beta1_pow[0]->data<float>());
} else {
DenseTensor beta1_pow_host;
beta1_pow_host.Resize(beta1_pow[0]->dims());
dev_ctx.template HostAlloc<float>(&beta1_pow_host);
phi::Copy(dev_ctx, *beta1_pow[0], CPUPlace(), false, &beta1_pow_host);
beta1_pow_data = *(beta1_pow_host.template data<float>());
}

float beta2_pow_data;
if (beta2_pow[0]->place() == CPUPlace()) {
beta2_pow_data = *(beta2_pow[0]->data<float>());
} else {
DenseTensor beta2_pow_host;
beta2_pow_host.Resize(beta2_pow[0]->dims());
dev_ctx.template HostAlloc<float>(&beta2_pow_host);
phi::Copy(dev_ctx, *beta2_pow[0], CPUPlace(), false, &beta2_pow_host);
beta2_pow_data = *(beta2_pow_host.template data<float>());
}

int param_num = param.size();
PADDLE_ENFORCE_EQ(param_num,
param_out.size(),
errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
param_out.size(),
param_num));
PADDLE_ENFORCE_EQ(
param_num,
moment1_out.size(),
errors::InvalidArgument(
"The size of Output(Moment1Out) must be equal to Input(Param), but got "
"the size of Output(Moment1Out) is %d, the size of Input(Param) is %d.",
moment1_out.size(),
param_num));
PADDLE_ENFORCE_EQ(
param_num,
moment2_out.size(),
errors::InvalidArgument(
"The size of Output(Moment2Out) must be equal to Input(Param), but got "
"the size of Output(Moment2Out) is %d, the size of Input(Param) is %d.",
moment2_out.size(),
param_num));
PADDLE_ENFORCE_EQ(param_num,
beta1_pow_out.size(),
errors::InvalidArgument(
"The size of Output(Beta1PowOut) must be equal to "
"Input(Param), but got the size of Output(Beta1PowOut) "
"is %d, the size of Input(Param) is %d.",
beta1_pow_out.size(),
param_num));
PADDLE_ENFORCE_EQ(param_num,
beta2_pow_out.size(),
errors::InvalidArgument(
"The size of Output(Beta2PowOut) must be equal to "
"Input(Param), but got the size of Output(Beta2PowOut) "
"is %d, the size of Input(Param) is %d.",
beta2_pow_out.size(),
param_num));
PADDLE_ENFORCE_EQ(
param_num,
grad.size(),
errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grad.size(),
param_num));
PADDLE_ENFORCE_EQ(
param_num,
moment1.size(),
errors::InvalidArgument(
"The size of Input(Moment1) must be equal to Input(Param), but got "
"the size of Input(Moment1) is %d, the size of Input(Param) is %d.",
moment1.size(),
param_num));
PADDLE_ENFORCE_EQ(
param_num,
moment2.size(),
errors::InvalidArgument(
"The size of Input(Moment2) must be equal to Input(Param), but got "
"the size of Input(Moment2) is %d, the size of Input(Param) is %d.",
moment2.size(),
param_num));

std::vector<float*> param_list(param_num);
std::vector<float*> grad_list(param_num);
std::vector<float*> moment1_list(param_num);
std::vector<float*> moment2_list(param_num);
std::vector<int64_t> shape_list(param_num);

for (int j = 0; j < param_num; j++) {
param_list[j] = const_cast<float*>(param[j]->data<float>());
grad_list[j] = const_cast<float*>(grad[j]->data<float>());
moment1_list[j] = const_cast<float*>(moment1[j]->data<float>());
moment2_list[j] = const_cast<float*>(moment2[j]->data<float>());
shape_list[j] = param[j]->numel();

PADDLE_ENFORCE_EQ(
param[j],
param_out[j],
errors::InvalidArgument("Input(Param) and Output(ParamOut) "
"must be the same Tensor."));
PADDLE_ENFORCE_EQ(
moment1[j],
moment1_out[j],
errors::InvalidArgument("Input(Moment1) and Output(Moment1Out) "
"must be the same Tensor."));
PADDLE_ENFORCE_EQ(
moment2[j],
moment2_out[j],
errors::InvalidArgument("Input(Moment2) and Output(Moment2Out) "
"must be the same Tensor."));

dev_ctx.template Alloc<float>(param_out[j]);
dev_ctx.template Alloc<float>(moment1_out[j]);
dev_ctx.template Alloc<float>(moment2_out[j]);
}

int r = xpu::multi_tensor_adam(dev_ctx.x_context(),
grad_list,
param_list,
moment1_list,
moment2_list,
shape_list,
lr_,
beta1_,
beta2_,
epsilon_,
step_,
mode_,
bias_correction_,
weight_decay_,
beta1_pow_data,
beta2_pow_data);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_adam");

// update param, moment1, moment2
for (int i = 0; i < param_num; i++) {
phi::Copy(dev_ctx, *param[i], dev_ctx.GetPlace(), false, param_out[i]);
Review comment (Contributor): Can this be optimized? The in-place update could be done inside the API.

Reply (Contributor, author): This is indeed one of the reasons merged_adam's performance is not better: the operator's API interface is not written as a standard in-place one, and the PyTorch side does not want us to change the interface, so for now it has to stay this way.

phi::Copy(dev_ctx, *moment1[i], dev_ctx.GetPlace(), false, moment1_out[i]);
phi::Copy(dev_ctx, *moment2[i], dev_ctx.GetPlace(), false, moment2_out[i]);
}

if (!use_global_beta_pow) {
for (int i = 0; i < param_num; i++) {
if (beta1_pow[i]->place() == CPUPlace() &&
beta2_pow[i]->place() == CPUPlace()) {
funcs::SetBetaData<Context, float>(
*beta1_pow[i], beta1_pow_out[i], beta1_, dev_ctx);

funcs::SetBetaData<Context, float>(
*beta2_pow[i], beta2_pow_out[i], beta2_, dev_ctx);
} else {
float* beta1_pow_out_ptr = nullptr;
const float* beta1_pow_data = beta1_pow[i]->data<float>();
beta1_pow_out_ptr = dev_ctx.template Alloc<float>(beta1_pow_out[i]);
r = xpu::scale(dev_ctx.x_context(),
beta1_pow_data,
beta1_pow_out_ptr,
beta1_pow[i]->numel(),
false,
beta1_,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_adam");

float* beta2_pow_out_ptr = nullptr;
const float* beta2_pow_data = beta2_pow[i]->data<float>();
beta2_pow_out_ptr = dev_ctx.template Alloc<float>(beta2_pow_out[i]);
r = xpu::scale(dev_ctx.x_context(),
beta2_pow_data,
beta2_pow_out_ptr,
beta2_pow[i]->numel(),
false,
beta2_,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_adam");
}
}
}
}
} // namespace phi

PD_REGISTER_KERNEL(
@@ -252,3 +475,11 @@ PD_REGISTER_KERNEL(
kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
}

PD_REGISTER_KERNEL(merged_adam, XPU, ALL_LAYOUT, phi::MergedAdamKernel, float) {
// Skip beta1_pow, beta2_pow data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
}
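To make the update that the kernel delegates to `xpu::multi_tensor_adam` easier to follow, here is a minimal NumPy sketch of one merged Adam step over a list of tensors, matching the reference formula used in the test below (zero weight decay, epsilon folded into the bias-corrected denominator). The function name `merged_adam_step` is purely illustrative and not part of any Paddle or XPU API.

```python
# Minimal NumPy sketch of one merged Adam step: the same per-tensor Adam
# update applied to every (param, grad, m1, m2) tuple, plus the beta-pow
# bookkeeping the kernel performs after the fused call when
# use_global_beta_pow is false. Illustrative only.
import numpy as np

def merged_adam_step(params, grads, m1s, m2s, lr, beta1, beta2, eps,
                     beta1_pow, beta2_pow):
    # Bias-corrected learning rate, shared by every tensor in the group
    # (the kernel reads only learning_rate[0]).
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    new_params, new_m1s, new_m2s = [], [], []
    for p, g, m1, m2 in zip(params, grads, m1s, m2s):
        m1 = beta1 * m1 + (1 - beta1) * g
        m2 = beta2 * m2 + (1 - beta2) * np.square(g)
        p = p - lr_t * m1 / (np.sqrt(m2) + eps * np.sqrt(1 - beta2_pow))
        new_params.append(p)
        new_m1s.append(m1)
        new_m2s.append(m2)
    return new_params, new_m1s, new_m2s, beta1_pow * beta1, beta2_pow * beta2
```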
4 changes: 3 additions & 1 deletion test/xpu/test_adam_op_xpu.py
@@ -271,7 +271,9 @@ def adam_step(inputs, attributes):
     moment1_out = beta1 * moment1 + (1 - beta1) * grad
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
     lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
-    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
+    param_out = param - lr_t * (
+        moment1_out / (np.sqrt(moment2_out) + epsilon * np.sqrt(1 - beta2_pow))
+    )
     return param_out, moment1_out, moment2_out
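
The reference change above switches the test to the textbook Adam update, in which epsilon is added to the square root of the bias-corrected second moment rather than to the raw one. A quick NumPy check of that equivalence, using made-up values, is shown here.

```python
# Check (illustrative values): the updated reference
#   param - lr_t * m / (sqrt(v) + eps * sqrt(1 - beta2_pow))
# equals the textbook form
#   param - lr * m_hat / (sqrt(v_hat) + eps),
# with m_hat = m / (1 - beta1_pow) and v_hat = v / (1 - beta2_pow).
import numpy as np

rng = np.random.default_rng(0)
param, m = rng.normal(size=3), rng.normal(size=3)
v = rng.random(3)  # second moment is non-negative
lr, eps = 1e-3, 1e-8
beta1_pow, beta2_pow = 0.9 ** 5, 0.999 ** 5

lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
fused = param - lr_t * m / (np.sqrt(v) + eps * np.sqrt(1 - beta2_pow))

m_hat, v_hat = m / (1 - beta1_pow), v / (1 - beta2_pow)
textbook = param - lr * m_hat / (np.sqrt(v_hat) + eps)

assert np.allclose(fused, textbook)
```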

