Skip to content

Commit b8ae385

Browse files
Silv3S, chenwhql, YuanRisheng
authored
[PHI] Migrate softmax kernel (PaddlePaddle#47339)
* add extra attr property set * add type_info for all context * add onednn context to all context * fix context compile error * simplify conv kernel args * pass runtime attr into dev_ctx * fix marco error * clear conv_grad_kernel extra args * merge conv_grad_grad into conv_grad * clear conv2d_grad_grad extra attrs * remove redundant imports * migrate softmax * clear yaml and eager extra attr * fix conv1d error * change to thread local * fix npu compile failed * try to fix windows compile failed * add conv2d onednn phi kernel * fix ci bugs (#36) * fix compile bugs (#38) * fix extra input transform bug (#39) * support dynamic created attr (#40) * reset extra info gen code * rm conv_grad_grad kernel * reimpl pass attr adapting * add int attr support * remove vector inputnames creating * merge dev * fix map at error * adjust attribute * adapt funcs to PHI Co-authored-by: Chen Weihang <chenweihang@baidu.com> Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
1 parent f9a0605 commit b8ae385

File tree

6 files changed

+72
-116
lines changed

6 files changed

+72
-116
lines changed

paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "paddle/phi/core/kernel_registry.h"
2424

2525
USE_OP_ITSELF(softmax);
26-
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
26+
PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN);
2727
USE_OP_ITSELF(elementwise_add);
2828
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
2929
USE_OP_ITSELF(leaky_relu);

paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc

Lines changed: 0 additions & 111 deletions
This file was deleted.

paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
3434
USE_OP_ITSELF(relu);
3535
PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
3636
USE_OP_ITSELF(softmax);
37-
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
37+
PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN);
3838
USE_OP_ITSELF(conv2d);
3939
PD_DECLARE_KERNEL(conv2d, OneDNN, ONEDNN);
4040

paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
3232
USE_OP_ITSELF(relu);
3333
PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
3434
USE_OP_ITSELF(softmax);
35-
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
36-
35+
PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN);
3736
PD_DECLARE_KERNEL(softmax, CPU, ALL_LAYOUT);
3837

3938
namespace paddle {

paddle/phi/backends/onednn/onednn_reuse.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,12 +753,19 @@ class SoftmaxOneDNNHandler
753753
public:
754754
SoftmaxOneDNNHandler(const dnnl::engine onednn_engine,
755755
Place cpu_place,
756+
int axis,
756757
const DenseTensor* x,
757-
int axis)
758+
DenseTensor* out)
758759
: OneDNNHandlerNoCachingT<T,
759760
dnnl::softmax_forward,
760761
dnnl::softmax_backward>(onednn_engine,
761762
cpu_place) {
763+
PADDLE_ENFORCE_EQ(
764+
x->dims(),
765+
out->dims(),
766+
phi::errors::InvalidArgument(
767+
"The shape of input and output tensor must be identical."));
768+
762769
const int canonical_axis = funcs::CanonicalAxis(axis, x->dims().size());
763770
this->AcquireForwardPrimitiveDescriptor(
764771
dnnl::prop_kind::forward_scoring, x->mem_desc(), canonical_axis);
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/softmax_kernel.h"
16+
17+
#include "paddle/phi/backends/onednn/onednn_reuse.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
namespace phi {
21+
22+
template <typename T, typename Context>
23+
void SoftmaxKernel(const Context& dev_ctx,
24+
const DenseTensor& x,
25+
int axis,
26+
DenseTensor* out) {
27+
funcs::SoftmaxOneDNNHandler<T> handler(
28+
dev_ctx.GetEngine(), dev_ctx.GetPlace(), axis, &x, out);
29+
30+
auto src_memory_p = handler.AcquireSrcMemory(&x);
31+
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
32+
if (x.IsSharedBufferWith(*out)) {
33+
dst_memory_p = src_memory_p;
34+
dev_ctx.template Alloc<T>(out);
35+
} else {
36+
dst_memory_p = handler.AcquireDstMemory(out);
37+
}
38+
auto softmax_p = handler.AcquireForwardPrimitive();
39+
40+
auto& astream = OneDNNContext::tls().get_stream();
41+
softmax_p->execute(
42+
astream, {{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}});
43+
astream.wait();
44+
45+
bool is_test = dev_ctx.HasDnnAttr("is_test")
46+
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
47+
: false;
48+
if (!is_test) {
49+
T* out_data = dev_ctx.template Alloc<T>(out);
50+
std::for_each(out_data, &out_data[out->numel()], [](T& val) {
51+
val = std::max(val, static_cast<T>(exp(-64)));
52+
});
53+
}
54+
55+
out->set_mem_desc(dst_memory_p->get_desc());
56+
}
57+
58+
} // namespace phi
59+
60+
PD_REGISTER_KERNEL(
61+
softmax, OneDNN, ONEDNN, phi::SoftmaxKernel, float, phi::dtype::bfloat16) {}

0 commit comments

Comments
 (0)