
Add multi tensor for adam #38010


Merged · 7 commits · Jan 7, 2022
28 changes: 12 additions & 16 deletions paddle/fluid/operators/optimizers/adam_op.cu
@@ -29,20 +29,18 @@ __global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
   MT beta1_pow = beta1_pow_;
   MT beta2_pow = beta2_pow_;

-  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
-        (static_cast<MT>(1.0) - beta1_pow);
-
   int id = blockIdx.x * blockDim.x + threadIdx.x;

   for (; id < ndim; id += gridDim.x * blockDim.x) {
     MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
     MT g = static_cast<MT>(grad[id]);
-    MT mom1 = moment1[id];
-    MT mom2 = moment2[id];
+    MT mom1 = static_cast<MT>(moment1[id]);
+    MT mom2 = static_cast<MT>(moment2[id]);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p -= lr * (mom1 /
-               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+
+    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
+    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));

     moment1_out[id] = mom1;
     moment2_out[id] = mom2;
@@ -65,9 +63,6 @@ __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
   MT beta1_pow = *beta1_pow_;
   MT beta2_pow = *beta2_pow_;

-  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
-        (static_cast<MT>(1.0) - beta1_pow);
-
   int id = blockIdx.x * blockDim.x + threadIdx.x;

   for (; id < ndim; id += gridDim.x * blockDim.x) {
@@ -77,8 +72,9 @@ __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
     MT mom2 = static_cast<MT>(moment2[id]);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p -= lr * (mom1 /
-               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+
+    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
+    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));

     moment1_out[id] = mom1;
     moment2_out[id] = mom2;
@@ -105,8 +101,6 @@ __global__ void SparseAdamCUDAKernelREG(
     int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   MT lr = *lr_;
-  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
-        (static_cast<MT>(1.0) - beta1_pow);

   for (; id < ndim; id += blockDim.x * gridDim.x) {
     auto row_idx =
@@ -122,8 +116,10 @@ __global__ void SparseAdamCUDAKernelREG(
                    : static_cast<MT>(0);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p -= lr * (mom1 / (sqrt(mom2) +
-                       epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+
+    MT denom =
+        (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
+    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));

     // Write back to global memory
     mom1_out_[id] = mom1;
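The removed and added lines above are an algebraically equivalent rewrite of the bias-corrected Adam step: the bias corrections move from a one-time rescaling of lr before the loop into the per-element formula. A short derivation, written with the kernel's own variable names (lr, mom1, mom2, beta1_pow, beta2_pow, epsilon), is sketched below:

$$
lr \cdot \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \cdot \frac{mom_1}{\sqrt{mom_2} + \epsilon \sqrt{1 - \beta_{2\_pow}}}
= \frac{lr}{1 - \beta_{1\_pow}} \cdot \frac{mom_1}{\dfrac{\sqrt{mom_2}}{\sqrt{1 - \beta_{2\_pow}}} + \epsilon}
$$

Dividing the numerator and denominator of the left-hand fraction by sqrt(1 - beta2_pow) gives the right-hand side, so the new p += (mom1 / denom) * (-(lr / (1 - beta1_pow))) takes the same step as the old p -= lr * (...) form, up to floating-point rounding.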
138 changes: 138 additions & 0 deletions paddle/fluid/operators/optimizers/merged_adam_op.cc
@@ -0,0 +1,138 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/merged_adam_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class MergedAdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {}

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto param_dtype =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(param_dtype, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
var_name == "SkipUpdate") {
return expected_kernel_type;
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
};

class MergedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "(Tensor, default Tensor<float>) Input parameter")
.AsDuplicable();
AddInput("Grad", "(Tensor, default Tensor<float>) Input gradient")
.AsDuplicable();
AddInput("LearningRate", "(Tensor, default Tensor<float>) Learning rate")
.AsDuplicable();
AddInput("Moment1", "(Tensor, default Tensor<float>) Input first moment")
.AsDuplicable();
AddInput("Moment2", "(Tensor, default Tensor<float>) Input second moment")
.AsDuplicable();
AddInput("Beta1Pow",
"(Tensor, default Tensor<float>) Input beta1 power accumulator")
.AsDuplicable();
AddInput("Beta2Pow",
"(Tensor, default Tensor<float>) Input beta2 power accumulator")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();

AddOutput("ParamOut", "(Tensor) Output parameter").AsDuplicable();
AddOutput("Moment1Out", "(Tensor) Output first moment").AsDuplicable();
AddOutput("Moment2Out", "(Tensor) Output second moment").AsDuplicable();
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator")
.AsDuplicable();
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable()
.AsDuplicable();

AddAttr<float>("beta1",
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates.")
.SetDefault(0.9f);
AddAttr<float>("beta2",
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates.")
.SetDefault(0.999f);
AddAttr<float>("epsilon",
"(float, default 1.0e-8) "
"Constant for numerical stability")
.SetDefault(1.0e-8f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
// TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut
// as dispensable since they are not used when use_global_beta_pow is true.
AddAttr<bool>("use_global_beta_pow",
"(bool, default false) "
"Whether to use global beta_pow for whole model instead of "
"creating beta_pow for each parameter.")
.SetDefault(false);

AddComment(R"DOC(
Adam Optimizer.
This implements the Adam optimizer from Section 2 of the Adam
paper : https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2\_out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
\frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(merged_adam, ops::MergedAdamOp,
ops::MergedAdamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(merged_adamw, ops::MergedAdamOp,
ops::MergedAdamOpMaker);

REGISTER_OP_CPU_KERNEL(
merged_adam,
ops::MergedAdamOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::MergedAdamOpKernel<paddle::platform::CPUDeviceContext, double>);
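
To illustrate what the merged op computes, here is a minimal, framework-free CPU sketch (not the Paddle kernel; the types and names below are illustrative only). It applies the same bias-corrected update used in adam_op.cu above to a list of parameter tensors in a single call, mirroring the duplicable Param/Grad/Moment1/Moment2 inputs declared by MergedAdamOpMaker:

// merged_adam_sketch.cc -- illustrative only, not the Paddle implementation.
#include <cmath>
#include <cstdio>
#include <vector>

struct AdamState {                        // per-parameter state, analogous to the
  std::vector<float> param, grad;         // duplicable Param/Grad/Moment1/Moment2
  std::vector<float> moment1, moment2;    // inputs of merged_adam
  float beta1_pow = 1.0f, beta2_pow = 1.0f;
};

// One fused call updates every parameter, instead of running one optimizer
// op per parameter as the plain adam op does.
void MergedAdamUpdate(std::vector<AdamState>& states, float lr,
                      float beta1 = 0.9f, float beta2 = 0.999f,
                      float epsilon = 1e-8f) {
  for (AdamState& s : states) {
    s.beta1_pow *= beta1;                 // beta1^t, beta2^t accumulators
    s.beta2_pow *= beta2;
    for (size_t i = 0; i < s.param.size(); ++i) {
      float g = s.grad[i];
      float m1 = beta1 * s.moment1[i] + (1.0f - beta1) * g;
      float m2 = beta2 * s.moment2[i] + (1.0f - beta2) * g * g;
      // Same bias-corrected form as the updated CUDA kernels above.
      float denom = std::sqrt(m2) / std::sqrt(1.0f - s.beta2_pow) + epsilon;
      s.param[i] += (m1 / denom) * (-(lr / (1.0f - s.beta1_pow)));
      s.moment1[i] = m1;
      s.moment2[i] = m2;
    }
  }
}

int main() {
  std::vector<AdamState> states(2);
  states[0].param = {1.0f, 2.0f};   states[0].grad = {0.1f, -0.2f};
  states[0].moment1 = {0.0f, 0.0f}; states[0].moment2 = {0.0f, 0.0f};
  states[1].param = {3.0f};         states[1].grad = {0.3f};
  states[1].moment1 = {0.0f};       states[1].moment2 = {0.0f};
  MergedAdamUpdate(states, /*lr=*/1e-3f);
  std::printf("%f %f %f\n", states[0].param[0], states[0].param[1],
              states[1].param[0]);
  return 0;
}

The per-element arithmetic is unchanged relative to the single-parameter op; the point of fusing is to cut per-op dispatch and kernel-launch overhead when a model has many small parameters.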