
Commit 09289e6

add xpu lars_momentum/pow2_decay
*test=kunlun
1 parent 0aa344f commit 09289e6

File tree

7 files changed: +432 -2 lines changed

cmake/external/xpu.cmake

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220718")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220719")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220718")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220719")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
Lines changed: 115 additions & 0 deletions (new file)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"

namespace paddle {
namespace operators {

template <typename T>
class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    bool multi_precision = ctx.Attr<bool>("multi_precision");
    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
    auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
    auto master_param_out =
        ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    T lars_coeff = ctx.Attr<float>("lars_coeff");
    T epsilon = ctx.Attr<float>("epsilon");
    T rescale_grad = ctx.Attr<float>("rescale_grad");

    std::vector<T*> param_list;
    std::vector<T*> grad_list;
    std::vector<T*> param_out_list;
    std::vector<float*> velocity_list;
    std::vector<float*> velocity_out_list;
    std::vector<float*> lrs;
    std::vector<int> param_sizes;

    std::vector<float*> master_param_list;
    std::vector<float*> master_param_out_list;
    int op_num = param.size();
    for (int i = 0; i < op_num; ++i) {
      param_list.push_back(const_cast<T*>(param[i]->data<T>()));
      grad_list.push_back(const_cast<T*>(grad[i]->data<T>()));
      param_out_list.push_back(param_out[i]->mutable_data<T>(ctx.GetPlace()));
      velocity_list.push_back(const_cast<float*>(velocity[i]->data<float>()));
      velocity_out_list.push_back(
          velocity_out[i]->mutable_data<float>(ctx.GetPlace()));
      lrs.push_back(const_cast<float*>(learning_rate[i]->data<float>()));
      param_sizes.push_back(param[i]->numel());

      PADDLE_ENFORCE_EQ(
          param_list[i],
          param_out_list[i],
          platform::errors::InvalidArgument(
              "Input(Param) and Output(ParamOut) must be the same Tensors."));
      PADDLE_ENFORCE_EQ(velocity_list[i],
                        velocity_out_list[i],
                        platform::errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
      if (multi_precision) {
        master_param_list.push_back(
            const_cast<float*>(master_param[i]->data<float>()));
        master_param_out_list.push_back(
            master_param_out[i]->mutable_data<float>(ctx.GetPlace()));
        PADDLE_ENFORCE_EQ(master_param_list[i],
                          master_param_out_list[i],
                          platform::errors::InvalidArgument(
                              "Input(MasterParam) and Output(MasterParamOut) "
                              "must be the same Tensors."));
      } else {
        master_param_list.push_back(nullptr);
        master_param_out_list.push_back(nullptr);
      }
    }

    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
    int r = lars_momentum(dev_ctx.x_context(),
                          param_list,
                          grad_list,
                          velocity_list,
                          lrs,
                          master_param_list,
                          param_out_list,
                          velocity_out_list,
                          master_param_out_list,
                          weight_decay_arr,
                          param_sizes,
                          mu,
                          lars_coeff,
                          epsilon,
                          rescale_grad);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "lars_momentum");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(lars_momentum, ops::LarsMomentumOpXPUKernel<float>);
#endif
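The kernel is not exposed to Python directly; it is reached through the existing lars_momentum op. Below is a hedged sketch of how it might be exercised end to end in static-graph mode. The toy network, variable names, and XPU device index are assumptions for illustration, not part of this commit; it presumes an XPU build of Paddle.

    # Hypothetical usage sketch: routes training through the lars_momentum
    # op that this commit implements for XPU.
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.data(name='x', shape=[None, 13], dtype='float32')
        y = fluid.data(name='y', shape=[None, 1], dtype='float32')
        pred = fluid.layers.fc(input=x, size=1)
        loss = fluid.layers.mean(fluid.layers.square_error_cost(pred, y))
        # LarsMomentumOptimizer emits lars_momentum ops in the program.
        opt = fluid.optimizer.LarsMomentumOptimizer(
            learning_rate=0.001, momentum=0.9, lars_weight_decay=0.0005)
        opt.minimize(loss)

    exe = fluid.Executor(paddle.XPUPlace(0))  # assumes an XPU device 0
    exe.run(startup)

Since only a float32 instance is registered above, a multi_precision/FP16 path would still fall back to other devices at this point.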
Lines changed: 84 additions & 0 deletions (new file)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"

namespace paddle {
namespace operators {

template <typename T>
class Pow2DecayWithLinearWarmupXPUOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const {
    const auto *lr = ctx.Input<framework::Tensor>("LearningRate");
    const auto *step = ctx.Input<framework::Tensor>("Step");
    auto *lr_out = ctx.Output<framework::Tensor>("LearningRateOut");
    auto *step_out = ctx.Output<framework::Tensor>("StepOut");
    PADDLE_ENFORCE_EQ(
        lr,
        lr_out,
        platform::errors::InvalidArgument("Input(LearningRate) and "
                                          "Output(LearningRateOut) "
                                          "must be the same."));
    PADDLE_ENFORCE_NOT_NULL(lr,
                            platform::errors::InvalidArgument(
                                "Input(LearningRate) should not be nullptr."));
    PADDLE_ENFORCE_EQ(step,
                      step_out,
                      platform::errors::InvalidArgument(
                          "Input(Step) and Output(StepOut) must be the same."));
    PADDLE_ENFORCE_NOT_NULL(step,
                            platform::errors::InvalidArgument(
                                "Input(Step) should not be nullptr."));
    PADDLE_ENFORCE_EQ(
        step->IsInitialized(),
        true,
        platform::errors::InvalidArgument("Input(Step) must be initialized."));

    auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
    auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
    PADDLE_ENFORCE_LE(warmup_steps,
                      total_steps,
                      platform::errors::InvalidArgument(
                          "warmup_steps must not be larger than total_steps."));
    auto base_lr = ctx.Attr<float>("base_lr");
    auto end_lr = ctx.Attr<float>("end_lr");

    auto *lr_data = lr_out->data<T>();
    auto *step_data = step_out->data<int64_t>();
    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
    int r = xpu::pow2_decay_with_linear_warmup(dev_ctx.x_context(),
                                               lr_data,
                                               step_data,
                                               warmup_steps,
                                               total_steps,
                                               base_lr,
                                               end_lr);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow2_decay_with_linear_warmup");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(pow2_decay_with_linear_warmup,
                       ops::Pow2DecayWithLinearWarmupXPUOpKernel<float>);
#endif
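The XDNN call receives only the raw lr/step buffers plus the four attributes, so the schedule itself is computed on device. For reference, here is a Python sketch of the schedule as implemented by Paddle's CPU kernel for this op: linear warmup from 0 to base_lr, then quadratic (pow2) decay down to end_lr. The pre-increment of step and the exact boundary handling are assumptions mirrored from that implementation, not taken from this diff.

    # Reference sketch of the pow2_decay_with_linear_warmup schedule
    # (assumed semantics, mirrored from Paddle's CPU kernel for this op).
    def pow2_decay_with_linear_warmup(step, warmup_steps, total_steps,
                                      base_lr, end_lr):
        step = step + 1  # the op also writes step + 1 back to StepOut
        if step <= warmup_steps:
            # linear warmup from 0 up to base_lr
            return base_lr * step / warmup_steps
        if step < total_steps:
            # quadratic decay from base_lr down to end_lr
            factor = 1.0 - (step - warmup_steps) / (total_steps - warmup_steps)
            return (base_lr - end_lr) * factor * factor + end_lr
        return end_lr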

paddle/fluid/platform/device/xpu/xpu2_op_list.h

Lines changed: 6 additions & 0 deletions
@@ -71,6 +71,8 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
      {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     {"coalesce_tensor",
+      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"concat_grad",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
@@ -255,6 +257,8 @@
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"label_smooth",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     {"lars_momentum",
+      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"layer_norm_grad",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"layer_norm_grad",
@@ -334,6 +338,8 @@
       pOpKernelType(vartype::FP16, XPUPlace())})},
      {"pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"pow_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     {"pow2_decay_with_linear_warmup",
+      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"range",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),

python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@
     'dropout_float16',
     'dropout_grad_float16',
     "grad_add_float32", # no api for grad_add, skip
+    "lars_momentum_float32",
     "resnet_unit",
     "resnet_unit_grad"
 ]
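Entries in this list tell the coverage checker to skip those instantiations rather than require a dedicated typed test. The same module can be used to confirm what the KL2 op list now reports for the ops this commit adds; an illustrative snippet (the expected output is an assumption for a KL2 device and an XPU build):

    # Illustrative only: query which dtypes the XPU op list reports,
    # using the helper the XPU unit tests import. Run from the
    # python/paddle/fluid/tests/unittests/xpu directory, as the tests do.
    import sys
    sys.path.append("..")
    from xpu.get_test_cover_info import get_xpu_op_support_types

    print(get_xpu_op_support_types('lars_momentum'))    # expected: ['float32']
    print(get_xpu_op_support_types('coalesce_tensor'))  # expected: ['float32']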
Lines changed: 128 additions & 0 deletions (new file)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from paddle.fluid import core
import sys

sys.path.append("..")
from op_test import OpTest

alignment = 256
import paddle
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper

paddle.enable_static()


class XPUTestCoalesceTensorOp(XPUOpTestWrapper):

    def __init__(self):
        self.op_name = 'coalesce_tensor'
        self.use_dynamic_create_class = False

    class TestAllocContinuousSpace(XPUOpTest):

        def setUp(self):
            self.op_type = "coalesce_tensor"
            self.use_xpu = True
            self.dtype, self.fluid_dtype = self.init_dtype()
            attrs = self.init_attr()
            self.copy_data = attrs["copy_data"]
            self.constant = attrs["constant"]
            self.set_constant = attrs["set_constant"]
            self.Inputs = self.init_input()
            self.Outputs, self.FusedOutput = self.init_output(
                self.Inputs, self.set_constant, self.constant)
            self.inputs = {'Input': self.Inputs}
            self.attrs = attrs
            self.outputs = {
                'Output': self.Outputs,
                'FusedOutput': self.FusedOutput
            }

        def init_dtype(self):
            return np.float32, int(core.VarDesc.VarType.FP32)

        def init_input(self):
            inputs = []
            inputs.append(("x1", np.random.random([20, 3]).astype(self.dtype)))
            inputs.append(("x2", np.random.random([20]).astype(self.dtype)))
            inputs.append(("x3", np.random.random([1]).astype(self.dtype)))
            inputs.append(("x4", np.random.random([200, 30]).astype(self.dtype)))
            inputs.append(("x5", np.random.random([30]).astype(self.dtype)))
            inputs.append(("x6", np.random.random([1]).astype(self.dtype)))
            return inputs

        def init_attr(self):
            return {
                "copy_data": True,
                "set_constant": False,
                "constant": 0.0,
                "dtype": self.fluid_dtype
            }

        def init_output(self, input_list, set_constant, constant):
            inputs = []
            outputs = input_list

            for input in input_list:
                length = len(input[1].flatten())
                aligned_len = (length + alignment) / alignment * alignment
                out = np.zeros(int(aligned_len))
                out[0:length] = input[1].flatten()
                inputs.append(out)

            coalesce_tensor_var = np.concatenate([input for input in inputs])
            if set_constant:
                coalesce_tensor_var = np.ones(
                    (len(coalesce_tensor_var))) * constant
                outputs = [(out[0],
                            np.ones(out[1].shape).astype(self.dtype) * constant)
                           for out in outputs]
            return outputs, coalesce_tensor_var

        def test_check_output(self):
            self.check_output_with_place(place=core.XPUPlace(0),
                                         no_check_set=["FusedOutput"],
                                         atol=1e-5)

    class TestAllocContinuousSpace2(TestAllocContinuousSpace):

        def init_attr(self):
            return {
                "copy_data": False,
                "set_constant": True,
                "constant": 0.5,
                "dtype": self.fluid_dtype,
                "user_defined_size_of_dtype": 2
            }

        def test_check_output(self):
            self.check_output_with_place(place=core.XPUPlace(0),
                                         no_check_set=["FusedOutput"],
                                         atol=1e-5)


support_types = get_xpu_op_support_types('coalesce_tensor')
for stype in support_types:
    create_test_class(globals(), XPUTestCoalesceTensorOp, stype)

if __name__ == '__main__':
    unittest.main()
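One detail worth noting in init_output above: FusedOutput is in no_check_set, so its exact layout never has to match the kernel's; the padding arithmetic only has to produce a buffer of roughly the right size. A quick illustrative check of what the expression actually computes:

    # Illustrative check of the padding used in init_output above.
    # (length + alignment) / alignment * alignment uses float division,
    # so it simplifies to length + alignment: each chunk gets `alignment`
    # extra zeros rather than being rounded up to a multiple of 256.
    alignment = 256
    length = 60  # e.g. the flattened 20x3 input x1
    aligned_len = (length + alignment) / alignment * alignment
    assert int(aligned_len) == length + alignment  # 316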
