Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9363019
add nadam cpu op
megemini Apr 13, 2024
6669a7c
test nadam cpu op
megemini Apr 15, 2024
fc0fafc
add nadam gpu op
megemini Apr 16, 2024
b3a7a1a
add nadam docstring
megemini Apr 16, 2024
b41a902
mod nadam docstring
megemini Apr 17, 2024
2752082
add radam op
megemini Apr 18, 2024
8afd3da
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Apr 18, 2024
2526180
fix conflict
megemini Apr 18, 2024
e8701c7
codestyle
megemini Apr 18, 2024
0833f2d
fix & add unittest
megemini Apr 19, 2024
d51cddd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Apr 19, 2024
d32eb2d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Apr 19, 2024
60f8ebe
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Apr 19, 2024
f9cbe72
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Apr 27, 2024
bf775a5
remove momentum_decay_base
megemini May 5, 2024
e3dd75a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini May 5, 2024
8d83089
[Change] accumulator scalar to tensor & add unittest
megemini May 6, 2024
05dab93
remove glog
megemini May 6, 2024
5213c2a
make test lr smaller
megemini May 7, 2024
e191d17
make test lr smaller
megemini May 7, 2024
e5f1ba0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini May 7, 2024
3586f31
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini May 7, 2024
1f39304
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini May 7, 2024
73230bf
nadam & radam test timeout
megemini May 8, 2024
f9e94a1
use fake test data instead of uci data
megemini May 9, 2024
6038faf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini May 10, 2024
57602c6
[Update] change nadam lr to 0.002
megemini May 13, 2024
45b6363
[Update] docstring
megemini May 14, 2024
482f4ae
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini May 15, 2024
9824325
[Update] unittest
megemini May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2204,6 +2204,18 @@
func : mv
backward : mv_grad

# NAdam optimizer update op (trailing "_" marks an in-place op).
# Accumulators (momentum_decay_pow, beta2_pow, mu_product, moment1, moment2)
# are tensors so they can live on-device; master_param is only used when
# multi_precision is enabled (fp16/bf16 training).
- op : nadam_
  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor momentum_decay_pow, Tensor beta2_pow, Tensor mu_product, Tensor moment1, Tensor moment2, Tensor master_param, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1.0e-8f, float momentum_decay = 0.004f, bool multi_precision = false)
  output : Tensor(param_out), Tensor(momentum_decay_pow_out), Tensor(beta2_pow_out), Tensor(mu_product_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(master_param_out)
  infer_meta :
    func : NAdamInferMeta
  kernel :
    func : nadam
    data_type : param
  optional : master_param, master_param_out
  inplace : (param -> param_out), (momentum_decay_pow -> momentum_decay_pow_out), (beta2_pow -> beta2_pow_out), (mu_product -> mu_product_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (master_param->master_param_out)
  traits : pir::SideEffectTrait

- op : nanmedian
args : (Tensor x, IntArray axis = {}, bool keepdim = true, str mode="avg")
output : Tensor(out), Tensor(medians)
Expand Down Expand Up @@ -2443,6 +2455,18 @@
func : qr
backward : qr_grad

# RAdam (Rectified Adam) optimizer update op (trailing "_" marks an in-place op).
# Accumulators (beta1_pow, beta2_pow, rho, moment1, moment2) are tensors so they
# can live on-device; master_param is only used when multi_precision is enabled
# (fp16/bf16 training).
- op : radam_
  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor beta1_pow, Tensor beta2_pow, Tensor rho, Tensor moment1, Tensor moment2, Tensor master_param, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1.0e-8f, bool multi_precision = false)
  output : Tensor(param_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(rho_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(master_param_out)
  infer_meta :
    func : RAdamInferMeta
  kernel :
    func : radam
    data_type : param
  optional : master_param, master_param_out
  inplace : (param -> param_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (rho -> rho_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (master_param->master_param_out)
  traits : pir::SideEffectTrait

- op : random_routing
args : (Tensor prob, Tensor topk_value, Tensor topk_idx)
output : Tensor(out)
Expand Down
165 changes: 165 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3592,6 +3592,89 @@ void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
out->set_dtype(ins[0]->dtype());
}

// Shape/dtype inference for the NAdam optimizer op.
//
// Validates that the moment accumulators (and master_param, when present)
// match the parameter's dimensions and that learning_rate is a scalar, then
// propagates dims/dtypes to every output. The scalar accumulator outputs
// (momentum_decay_pow, beta2_pow, mu_product) keep their input meta.
// master_param_out is only configured when multi_precision is on AND
// master_param is initialized; in that case its dtype is promoted to
// FLOAT32 for fp16/bf16 params.
void NAdamInferMeta(const MetaTensor& param,
                    const MetaTensor& grad,
                    const MetaTensor& learning_rate,
                    const MetaTensor& momentum_decay_pow,
                    const MetaTensor& beta2_pow,
                    const MetaTensor& mu_product,
                    const MetaTensor& moment1,
                    const MetaTensor& moment2,
                    const MetaTensor& master_param,
                    float beta1,
                    float beta2,
                    float epsilon,
                    float momentum_decay,
                    bool multi_precision,
                    MetaTensor* param_out,
                    MetaTensor* momentum_decay_pow_out,
                    MetaTensor* beta2_pow_out,
                    MetaTensor* mu_product_out,
                    MetaTensor* moment1_out,
                    MetaTensor* moment2_out,
                    MetaTensor* master_param_out) {
  const auto param_dim = param.dims();
  PADDLE_ENFORCE_EQ(param_dim,
                    moment1.dims(),
                    phi::errors::InvalidArgument(
                        "Param and Momentum input of NAdamOp "
                        "should have the same dimension. But received "
                        "Param's dim [%s] and Moment1 [%s]",
                        param_dim,
                        moment1.dims()));
  PADDLE_ENFORCE_EQ(param_dim,
                    moment2.dims(),
                    phi::errors::InvalidArgument(
                        "Param and Momentum input of NAdamOp "
                        "should have the same dimension. But received "
                        "Param's dim [%s] and Moment2 [%s]",
                        param_dim,
                        moment2.dims()));

  const auto lr_dim = learning_rate.dims();
  PADDLE_ENFORCE_EQ(common::product(lr_dim),
                    1,
                    phi::errors::InvalidArgument(
                        "Learning Rate of NAdamOp should be a scalar. But "
                        "received LearningRate's dim [%s]",
                        lr_dim));

  if (master_param.initialized()) {
    // master_param is optional; only check its shape when it is provided.
    PADDLE_ENFORCE_EQ(param_dim,
                      master_param.dims(),
                      phi::errors::InvalidArgument(
                          "Param and MasterParam input of NAdamOp should "
                          "have same dimension. But "
                          "received Param dims: [%s], MasterParam dims: [%s].",
                          param_dim,
                          master_param.dims()));
  }

  param_out->set_dims(param_dim);
  param_out->set_dtype(param.dtype());

  // Scalar accumulators pass their meta straight through.
  momentum_decay_pow_out->set_dims(momentum_decay_pow.dims());
  momentum_decay_pow_out->set_dtype(momentum_decay_pow.dtype());
  beta2_pow_out->set_dims(beta2_pow.dims());
  beta2_pow_out->set_dtype(beta2_pow.dtype());
  mu_product_out->set_dims(mu_product.dims());
  mu_product_out->set_dtype(mu_product.dtype());

  moment1_out->set_dims(param_dim);
  moment1_out->set_dtype(moment1.dtype());
  moment2_out->set_dims(param_dim);
  moment2_out->set_dtype(moment2.dtype());

  if (multi_precision && master_param.initialized()) {
    // Low-precision params keep a float32 master copy for accurate updates.
    const auto mp_dtype = (param.dtype() == phi::DataType::FLOAT16 ||
                           param.dtype() == phi::DataType::BFLOAT16)
                              ? phi::DataType::FLOAT32
                              : param.dtype();
    master_param_out->set_dims(param_dim);
    master_param_out->set_dtype(mp_dtype);
  }
}

void NceInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down Expand Up @@ -3785,6 +3868,88 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
}
}

void RAdamInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& rho,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& master_param,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* rho_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* master_param_out) {
auto param_dim = param.dims();
PADDLE_ENFORCE_EQ(param_dim,
moment1.dims(),
phi::errors::InvalidArgument(
"Param and Momentum input of RAdamOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment1 [%s]",
param_dim,
moment1.dims()));
PADDLE_ENFORCE_EQ(param_dim,
moment2.dims(),
phi::errors::InvalidArgument(
"Param and Momentum input of RAdamOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment2 [%s]",
param_dim,
moment2.dims()));

auto lr_dim = learning_rate.dims();
PADDLE_ENFORCE_EQ(common::product(lr_dim),
1,
phi::errors::InvalidArgument(
"Learning Rate of RAdamOp should be a scalar. But "
"received LearningRate's dim [%s]",
common::product(lr_dim)));

if (master_param.initialized()) {
PADDLE_ENFORCE_EQ(param_dim,
master_param.dims(),
errors::InvalidArgument(
"Param and MasterParam input of RAdamOp should "
"have same dimension. But "
"received Param dims: [%s], MasterParam dims: [%s].",
param_dim,
master_param.dims()));
}

param_out->set_dims(param_dim);
param_out->set_dtype(param.dtype());

beta1_pow_out->set_dims(beta1_pow.dims());
beta1_pow_out->set_dtype(beta1_pow.dtype());
beta2_pow_out->set_dims(beta2_pow.dims());
beta2_pow_out->set_dtype(beta2_pow.dtype());
rho_out->set_dims(rho.dims());
rho_out->set_dtype(rho.dtype());

moment1_out->set_dims(param_dim);
moment1_out->set_dtype(moment1.dtype());
moment2_out->set_dims(param_dim);
moment2_out->set_dtype(moment2.dtype());

if (multi_precision && master_param.initialized()) {
auto MPType = (param.dtype() == phi::DataType::FLOAT16 ||
param.dtype() == phi::DataType::BFLOAT16)
? phi::DataType::FLOAT32
: param.dtype();
master_param_out->set_dims(param_dim);
master_param_out->set_dtype(MPType);
}
}

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
Expand Down
43 changes: 43 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,28 @@ void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);

// Shape/dtype inference for the NAdam optimizer op. master_param /
// master_param_out are optional and only used when multi_precision is true.
// NOTE: renamed the final parameter from `master_param_outs` to
// `master_param_out` so the declaration matches the definition in
// multiary.cc and the single-tensor output it actually is.
void NAdamInferMeta(const MetaTensor& param,
                    const MetaTensor& grad,
                    const MetaTensor& learning_rate,
                    const MetaTensor& momentum_decay_pow,
                    const MetaTensor& beta2_pow,
                    const MetaTensor& mu_product,
                    const MetaTensor& moment1,
                    const MetaTensor& moment2,
                    const MetaTensor& master_param,
                    float beta1,
                    float beta2,
                    float epsilon,
                    float momentum_decay,
                    bool multi_precision,
                    MetaTensor* param_out,
                    MetaTensor* momentum_decay_pow_out,
                    MetaTensor* beta2_pow_out,
                    MetaTensor* mu_product_out,
                    MetaTensor* moment1_out,
                    MetaTensor* moment2_out,
                    MetaTensor* master_param_out);

void NceInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down Expand Up @@ -697,6 +719,27 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
MetaTensor* out_accum,
MetaTensor* out_state);

// Shape/dtype inference for the RAdam (Rectified Adam) optimizer op.
// master_param / master_param_out are optional and only used when
// multi_precision is true.
// NOTE: renamed the final parameter from `master_param_outs` to
// `master_param_out` so the declaration matches the definition in
// multiary.cc and the single-tensor output it actually is.
void RAdamInferMeta(const MetaTensor& param,
                    const MetaTensor& grad,
                    const MetaTensor& learning_rate,
                    const MetaTensor& beta1_pow,
                    const MetaTensor& beta2_pow,
                    const MetaTensor& rho,
                    const MetaTensor& moment1,
                    const MetaTensor& moment2,
                    const MetaTensor& master_param,
                    float beta1,
                    float beta2,
                    float epsilon,
                    bool multi_precision,
                    MetaTensor* param_out,
                    MetaTensor* beta1_pow_out,
                    MetaTensor* beta2_pow_out,
                    MetaTensor* rho_out,
                    MetaTensor* moment1_out,
                    MetaTensor* moment2_out,
                    MetaTensor* master_param_out);

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
Expand Down
21 changes: 21 additions & 0 deletions paddle/phi/kernels/cpu/nadam_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/nadam_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nadam_kernel_impl.h"

// Register the NAdam optimizer kernel for CPU. The shared templated
// implementation lives in impl/nadam_kernel_impl.h; CPU supports only
// float and double (no fp16/bf16 on this backend).
PD_REGISTER_KERNEL(nadam, CPU, ALL_LAYOUT, phi::NAdamKernel, float, double) {}
21 changes: 21 additions & 0 deletions paddle/phi/kernels/cpu/radam_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/radam_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/radam_kernel_impl.h"

// Register the RAdam optimizer kernel for CPU. The shared templated
// implementation lives in impl/radam_kernel_impl.h; CPU supports only
// float and double (no fp16/bf16 on this backend).
PD_REGISTER_KERNEL(radam, CPU, ALL_LAYOUT, phi::RAdamKernel, float, double) {}
Loading