Commit 296bfbc

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop
2 parents: d0218fb + 687902f


55 files changed: +1482, -566 lines

paddle/fluid/inference/api/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -56,8 +56,10 @@ cc_test(test_paddle_inference_api SRCS api_tester.cc DEPS paddle_inference_api)
 
 if(WITH_TESTING)
   if (NOT APPLE AND NOT WIN32)
-    inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS paddle_inference_shared
-        ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
+    if (WITH_GPU)
+      inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS paddle_inference_shared
+          ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})
+    endif()
   elseif(WIN32)
     inference_base_test(test_api_impl SRCS api_impl_tester.cc DEPS ${inference_deps}
         ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR})

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 10 additions & 4 deletions
@@ -299,7 +299,9 @@ inference_analysis_api_test(test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR}
 set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie")
 download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" aa59192dd41ed377f9f168e3a1309fa6 "Ernie_data.txt.tar.gz" 5396e63548edad7ca561e7e26a9476d1)
 download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz" 73beea65abda2edb61c1662cd3180c62)
-inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc)
+if (WITH_GPU)
+  inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc)
+endif()
 inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR} analyzer_ernie_int8_tester.cc)
 
 # Ernie large

@@ -551,7 +553,9 @@ endif()
 # bert, max_len=20, embedding_dim=128
 set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert_emb128")
 download_model_and_data_without_verify(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_data_len20.txt.tar.gz")
-inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)
+if (WITH_GPU)
+  inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)
+endif()
 
 # multiple models prediction
 set(MMP_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/multi_model_prediction")

@@ -741,13 +745,15 @@ set_tests_properties(lite_resnet50_test PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_mobilenet_transpose PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_resnet50 PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_ner PROPERTIES TIMEOUT 120)
-set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_ernie_int8 PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_googlenet PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_small_dam PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_transformer PROPERTIES TIMEOUT 120)
-set_tests_properties(test_analyzer_bert PROPERTIES TIMEOUT 120)
 set_tests_properties(test_analyzer_mobilenet_depthwise_conv PROPERTIES TIMEOUT 120)
+if (WITH_GPU)
+  set_tests_properties(test_analyzer_bert PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120)
+endif()
 if(WITH_GPU AND TENSORRT_FOUND)
   set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 120)
   if(WITH_MKLDNN)

paddle/fluid/operators/erfinv_op.cc

Lines changed: 1 addition & 14 deletions
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/erfinv_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {

@@ -85,16 +85,3 @@ REGISTER_OPERATOR(
     paddle::operators::ErfinvInplaceInferer);
 
 REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp);
-
-REGISTER_OP_CPU_KERNEL(
-    erfinv,
-    paddle::operators::ErfinvKernel<paddle::platform::CPUDeviceContext, float>,
-    paddle::operators::ErfinvKernel<paddle::platform::CPUDeviceContext,
-                                    double>);
-
-REGISTER_OP_CPU_KERNEL(
-    erfinv_grad,
-    paddle::operators::ErfinvGradKernel<paddle::platform::CPUDeviceContext,
-                                        float>,
-    paddle::operators::ErfinvGradKernel<paddle::platform::CPUDeviceContext,
-                                        double>);

paddle/fluid/operators/erfinv_op.h

Lines changed: 0 additions & 65 deletions
This file was deleted.
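The deleted header (and the REGISTER_OP_CPU_KERNEL registrations removed from erfinv_op.cc above) reflect Paddle's migration of device kernels out of the fluid operators and into the phi library; the same pattern applies to the eye kernels removed below. A minimal sketch of how such a CPU kernel registration typically looks on the phi side, assuming kernels named phi::ErfinvKernel and phi::ErfinvGradKernel exist there (this registration is not part of the diff shown here):

// Hypothetical phi-side registration sketch (not part of this commit):
// the CPU erfinv kernels removed from erfinv_op.cc would be registered
// with the phi kernel registry roughly like this.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/erfinv_grad_kernel.h"
#include "paddle/phi/kernels/erfinv_kernel.h"

PD_REGISTER_KERNEL(erfinv, CPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
PD_REGISTER_KERNEL(
    erfinv_grad, CPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}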

paddle/fluid/operators/eye_op.cc

Lines changed: 1 addition & 7 deletions
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/eye_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {

@@ -82,14 +82,8 @@ Return an identity tensor whose shape is [num_rows, num_columns].
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-using CPU = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(
     eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(eye, ops::EyeKernel<CPU, float>,
-                       ops::EyeKernel<CPU, double>,
-                       ops::EyeKernel<CPU, int64_t>, ops::EyeKernel<CPU, int>,
-                       ops::EyeKernel<CPU, paddle::platform::float16>);

paddle/fluid/operators/eye_op.cu

Lines changed: 0 additions & 24 deletions
This file was deleted.

paddle/fluid/operators/eye_op.h

Lines changed: 0 additions & 61 deletions
This file was deleted.

paddle/fluid/operators/eye_op_npu.cc

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/eye_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
 namespace paddle {

paddle/fluid/operators/log_softmax_op.cc

Lines changed: 15 additions & 3 deletions
@@ -31,9 +31,17 @@ class LogSoftmaxOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 

@@ -48,6 +56,10 @@ class LogSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
              "The dimension index of Input(x) to perform log_softmax,"
              "default -1 for last dimension")
         .SetDefault(-1);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 LogSoftmax Operator.
 
New file (log_softmax MKLDNN kernel)

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
class LogSoftmaxMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward> {
 public:
  LogSoftmaxMKLDNNHandler(const dnnl::engine mkldnn_engine,
                          platform::Place cpu_place, const Tensor* x,
                          const int axis)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
            mkldnn_engine, cpu_place) {
    const auto logsoftmax_tz = phi::vectorize(x->dims());
    const auto md = dnnl::memory::desc(
        logsoftmax_tz, platform::MKLDNNGetDataType<T>(), x->format());

    this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
                                            md, axis);
  }
};

template <typename T>
class LogSoftmaxMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* x = ctx.Input<Tensor>("X");
    Tensor* out = ctx.Output<Tensor>("Out");

    int axis = ctx.Attr<int>("axis");
    axis = axis >= 0 ? axis : x->dims().size() + axis;

    LogSoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), x, axis);

    auto src_memory_p = handler.AcquireSrcMemory(x);
    auto dst_memory_p = handler.AcquireDstMemory(out);

    auto logsoftmax_p = handler.AcquireForwardPrimitive();

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    logsoftmax_p->execute(astream, {{DNNL_ARG_SRC, *src_memory_p},
                                    {DNNL_ARG_DST, *dst_memory_p}});
    astream.wait();

    out->set_layout(framework::DataLayout::kMKLDNN);
    out->set_format(x->format());
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(log_softmax, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::LogSoftmaxMKLDNNKernel<float>,
                   ops::LogSoftmaxMKLDNNKernel<paddle::platform::bfloat16>);
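Taken together with the use_mkldnn attribute and the MKLDNN branch added to GetExpectedKernelType above, this new kernel is what runs when oneDNN dispatch is enabled on CPU. A minimal, hypothetical usage sketch with the Paddle Inference C++ API (the model directory is a placeholder, not part of this commit):

// Hypothetical usage sketch (not part of this commit): enabling oneDNN in the
// inference config routes CPU ops that have MKLDNN kernels, such as the
// log_softmax kernel registered above, through oneDNN.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./model_dir");  // placeholder model directory
  config.DisableGpu();             // run on CPU
  config.EnableMKLDNN();           // turn on oneDNN kernels
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... prepare input tensors and call predictor->Run() as usual ...
  return predictor != nullptr ? 0 : 1;
}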
