
Commit e6ec98f

[Phi] Move softmax with cross entropy kernel into phi (#40832)
* add cross_entropy_with_softmax phi kernel
* remove softmax_with_cross_entropy kernel
* add softmax_with_cross_entropy grad kernel
* remove original op kernel
* refine cross entropy impl
* fix pointer error
* revert kernel cu change
* fix xpu failed
* fix cinn failed
* fix npu failed
* add forward sig
* add check_nan_inf for pt kernel
* remove repeat cmake item
* fix unittest error
1 parent d65a7a4 commit e6ec98f

22 files changed: +1867 -1339 lines
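For context on what the migrated kernels compute: in the soft-label case the fused op applies softmax to the logits along the class axis and then evaluates loss_i = -sum_j label[i][j] * log(prob[i][j]), which matches the Eigen expression in cross_entropy.cc below. A minimal standalone C++ sketch of that math (illustrative only, not Paddle code):

// Standalone sketch (not Paddle code) of the soft-label math the fused kernel
// computes: prob = softmax(logits) per row, loss_i = -sum_j label[i][j] * log(prob[i][j]).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int batch = 2, classes = 3;
  std::vector<float> logits = {1.0f, 2.0f, 3.0f, 0.5f, 0.5f, 0.5f};
  std::vector<float> labels = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f};  // soft labels, each row sums to 1
  std::vector<float> prob(batch * classes);
  std::vector<float> loss(batch, 0.0f);

  for (int i = 0; i < batch; ++i) {
    // Numerically stable softmax over the class axis.
    float max_v = *std::max_element(logits.begin() + i * classes,
                                    logits.begin() + (i + 1) * classes);
    float sum = 0.0f;
    for (int j = 0; j < classes; ++j) {
      prob[i * classes + j] = std::exp(logits[i * classes + j] - max_v);
      sum += prob[i * classes + j];
    }
    // Cross entropy against the soft labels.
    for (int j = 0; j < classes; ++j) {
      prob[i * classes + j] /= sum;
      loss[i] -= labels[i * classes + j] * std::log(prob[i * classes + j]);
    }
  }
  for (int i = 0; i < batch; ++i) std::printf("loss[%d] = %f\n", i, loss[i]);
  return 0;
}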

paddle/fluid/framework/new_executor/standalone_executor_test.cc

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,7 @@ USE_OP_ITSELF(elementwise_add);
 USE_OP_ITSELF(sigmoid);
 USE_OP_ITSELF(tanh);
 USE_OP_ITSELF(elementwise_mul);
-USE_OP(softmax_with_cross_entropy);
+USE_OP_ITSELF(softmax_with_cross_entropy);
 USE_OP_ITSELF(reduce_mean);
 USE_OP_ITSELF(reduce_sum);
 USE_OP_ITSELF(reduce_sum_grad);
@@ -83,6 +83,8 @@ PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);

 DECLARE_double(eager_delete_tensor_gb);

paddle/fluid/framework/phi_utils.cc

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
   } else if (kernel_type.library_type_ == LibraryType::kKP) {
     backend = phi::Backend::KPS;
   } else {
-    // do
+    // do nothing
   }
   paddle::experimental::DataLayout layout = kernel_type.data_layout_;
   paddle::experimental::DataType dtype =

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 5 additions & 0 deletions
@@ -484,6 +484,11 @@ static void PreparedOpRunPtImpl(
     pt_kernel(&pt_kernel_context);
   }

+  if (FLAGS_check_nan_inf) {
+    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
+        op.Type(), outs, dev_ctx->GetPlace());
+  }
+
   if (FLAGS_benchmark) {
     dev_ctx->Wait();
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
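The new FLAGS_check_nan_inf branch brings the NaN/Inf output check to the phi kernel path in dygraph. Conceptually such a check scans each output tensor for non-finite values and fails loudly; a minimal sketch of the idea (illustrative names, not the actual framework::details::CheckOpHasNanOrInfInDygraph implementation):

// Illustrative-only sketch of a flag-guarded NaN/Inf output check:
// scan each output buffer and fail loudly on non-finite values.
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>

void CheckHasNanOrInf(const std::string& op_type, const std::vector<float>& output) {
  for (float v : output) {
    if (std::isnan(v) || std::isinf(v)) {
      throw std::runtime_error("Output of operator '" + op_type +
                               "' contains NaN or Inf");
    }
  }
}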

paddle/fluid/operators/math/cross_entropy.cc

Lines changed: 29 additions & 28 deletions
@@ -14,6 +14,7 @@ limitations under the License. */

 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"

 namespace paddle {
 namespace platform {
@@ -89,38 +90,38 @@ struct HardLabelCrossEntropyCPUFunctorImpl {
   const int axis_dim_;
 };

-template <typename T>
-class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
- public:
-  void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
-                  const framework::Tensor* prob,
-                  const framework::Tensor* labels, const bool softLabel,
-                  const int ignore_index, const int axis_dim) {
-    if (softLabel) {
-      const int batch_size = prob->dims()[0];
-      const int num_classes = prob->dims()[1];
-      const int num_remain = num_classes / axis_dim;
-
-      Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-      auto in = EigenMatrix<T>::From(*prob);
-      auto lbl = EigenMatrix<T>::From(*labels);
-      auto loss = EigenMatrix<T>::From(*out);
-
-      loss.device(*ctx.eigen_device()) =
-          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
-                .reshape(batch_axis_remain)
-                .sum(Eigen::DSizes<int, 1>(1)));
-    } else {
-      HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(
-          out, prob, labels, ignore_index, axis_dim);
-      framework::VisitIntDataType(
-          framework::TransToProtoVarType(labels->dtype()), functor_impl);
-    }
+template <typename DeviceContext, typename T>
+void CrossEntropyFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& ctx, framework::Tensor* out,
+    const framework::Tensor* prob, const framework::Tensor* labels,
+    const bool softLabel, const int ignore_index, const int axis_dim) {
+  if (softLabel) {
+    const int batch_size = prob->dims()[0];
+    const int num_classes = prob->dims()[1];
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    auto in = EigenMatrix<T>::From(*prob);
+    auto lbl = EigenMatrix<T>::From(*labels);
+    auto loss = EigenMatrix<T>::From(*out);
+
+    loss.device(*ctx.eigen_device()) =
+        -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
+              .reshape(batch_axis_remain)
+              .sum(Eigen::DSizes<int, 1>(1)));
+  } else {
+    HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(out, prob, labels,
+                                                        ignore_index, axis_dim);
+    framework::VisitIntDataType(framework::TransToProtoVarType(labels->dtype()),
+                                functor_impl);
   }
-};
+}

 template class CrossEntropyFunctor<platform::CPUDeviceContext, float>;
 template class CrossEntropyFunctor<platform::CPUDeviceContext, double>;
+
+template class CrossEntropyFunctor<phi::CPUContext, float>;
+template class CrossEntropyFunctor<phi::CPUContext, double>;
 } // namespace math
 } // namespace operators
 } // namespace paddle
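The CPU functor above changes from a class specialization on platform::CPUDeviceContext into a single member-function definition templated on the device context, followed by explicit instantiations for both the fluid and phi context types, so one definition serves both runtimes. A self-contained sketch of this C++ pattern (stand-in types, not Paddle's):

// Illustration of the refactor pattern: one templated member-function
// definition in the .cc, explicitly instantiated per device-context/type pair.
#include <cstdio>

struct FluidCPUContext {};  // stands in for platform::CPUDeviceContext
struct PhiCPUContext {};    // stands in for phi::CPUContext

// Declared in the header, templated on the context type.
template <typename DeviceContext, typename T>
struct CrossEntropyLikeFunctor {
  void operator()(const DeviceContext& ctx, T value);
};

// Defined once in the source file for any context type...
template <typename DeviceContext, typename T>
void CrossEntropyLikeFunctor<DeviceContext, T>::operator()(const DeviceContext&,
                                                           T value) {
  std::printf("value = %f\n", static_cast<double>(value));
}

// ...and explicitly instantiated for each context/type pair callers need, so
// both the old fluid context and the new phi context link against one definition.
template struct CrossEntropyLikeFunctor<FluidCPUContext, float>;
template struct CrossEntropyLikeFunctor<FluidCPUContext, double>;
template struct CrossEntropyLikeFunctor<PhiCPUContext, float>;
template struct CrossEntropyLikeFunctor<PhiCPUContext, double>;

int main() {
  CrossEntropyLikeFunctor<PhiCPUContext, float>()(PhiCPUContext{}, 0.5f);
  return 0;
}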

paddle/fluid/operators/math/cross_entropy.cu

Lines changed: 33 additions & 30 deletions
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"

 namespace paddle {
 namespace operators {
@@ -93,46 +94,48 @@ struct HardLabelCrossEntropyCUDAFunctorImpl {
   gpuStream_t stream_;
 };

-template <typename T>
-class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& ctx,
-                  framework::Tensor* out, const framework::Tensor* prob,
-                  const framework::Tensor* labels, const bool softLabel,
-                  const int ignore_index, const int axis_dim) {
-    const T* prob_data = prob->data<T>();
-    T* loss_data = out->mutable_data<T>(ctx.GetPlace());
-
-    int batch_size = prob->dims()[0];
-    int class_num = prob->dims()[1];
+template <typename DeviceContext, typename T>
+void CrossEntropyFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& ctx, framework::Tensor* out,
+    const framework::Tensor* prob, const framework::Tensor* labels,
+    const bool softLabel, const int ignore_index, const int axis_dim) {
+  const T* prob_data = prob->data<T>();
+  T* loss_data = out->mutable_data<T>(ctx.GetPlace());
+
+  int batch_size = prob->dims()[0];
+  int class_num = prob->dims()[1];
 #ifdef __HIPCC__
-    constexpr int kMaxBlockDim = 256;
+  constexpr int kMaxBlockDim = 256;
 #else
-    constexpr int kMaxBlockDim = 512;
+  constexpr int kMaxBlockDim = 512;
 #endif

-    if (softLabel) {
-      const T* label_data = labels->data<T>();
-      int block = class_num > kMaxBlockDim
-                      ? kMaxBlockDim
-                      : pow(2, static_cast<int>(std::log2(class_num)));
-
-      SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
-          loss_data, prob_data, label_data, class_num);
-    } else {
-      HardLabelCrossEntropyCUDAFunctorImpl<T> functor(
-          loss_data, prob_data, labels->data(), batch_size, class_num,
-          ignore_index, kMaxBlockDim, ctx.stream());
-      framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()),
-                               functor);
-    }
+  if (softLabel) {
+    const T* label_data = labels->data<T>();
+    int block = class_num > kMaxBlockDim
+                    ? kMaxBlockDim
+                    : pow(2, static_cast<int>(std::log2(class_num)));
+
+    SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
+        loss_data, prob_data, label_data, class_num);
+  } else {
+    HardLabelCrossEntropyCUDAFunctorImpl<T> functor(
+        loss_data, prob_data, labels->data(), batch_size, class_num,
+        ignore_index, kMaxBlockDim, ctx.stream());
+    framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()),
+                             functor);
   }
-};
+}

 template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
 template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
 template class CrossEntropyFunctor<platform::CUDADeviceContext,
                                    platform::float16>;
+
+template class CrossEntropyFunctor<phi::GPUContext, float>;
+template class CrossEntropyFunctor<phi::GPUContext, double>;
+template class CrossEntropyFunctor<phi::GPUContext, platform::float16>;
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
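In the soft-label launch above, one thread block is issued per batch row, and the block size is the largest power of two not exceeding class_num, capped at kMaxBlockDim (256 under HIP, 512 otherwise). A small standalone sketch of that rounding rule (PickBlockDim is an illustrative helper name, not a Paddle function):

#include <cmath>
#include <cstdio>

// Largest power of two <= class_num, capped at max_block_dim; mirrors the
// pow(2, floor(log2(class_num))) expression in the kernel launch above.
int PickBlockDim(int class_num, int max_block_dim) {
  if (class_num > max_block_dim) return max_block_dim;
  return 1 << static_cast<int>(std::log2(static_cast<double>(class_num)));
}

int main() {
  std::printf("%d %d %d\n",
              PickBlockDim(10, 512),     // 8
              PickBlockDim(512, 512),    // 512
              PickBlockDim(4096, 512));  // capped at 512
  return 0;
}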

paddle/fluid/operators/math/softmax.cu

Lines changed: 26 additions & 14 deletions
@@ -29,9 +29,9 @@ using DataLayout = platform::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;

-template <typename T>
-void SoftmaxCUDNNFunctor<T>::operator()(
-    const platform::CUDADeviceContext& context, const framework::Tensor* X,
+template <typename T, typename DeviceContext>
+void SoftmaxCUDNNFunctor<T, DeviceContext>::operator()(
+    const DeviceContext& context, const framework::Tensor* X,
     framework::Tensor* Y) {
   // ------------------- cudnn descriptors ---------------------
   ScopedTensorDescriptor xDesc;
@@ -69,9 +69,9 @@ void SoftmaxCUDNNFunctor<T>::operator()(
 #endif
 }

-template <typename T>
-void SoftmaxGradCUDNNFunctor<T>::operator()(
-    const platform::CUDADeviceContext& context, const framework::Tensor* Y,
+template <typename T, typename DeviceContext>
+void SoftmaxGradCUDNNFunctor<T, DeviceContext>::operator()(
+    const DeviceContext& context, const framework::Tensor* Y,
     const framework::Tensor* YGrad, framework::Tensor* XGrad) {
   // ------------------- cudnn descriptors ---------------------
   ScopedTensorDescriptor yDesc;
@@ -116,19 +116,31 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
 #endif
 }

-template class SoftmaxCUDNNFunctor<float>;
-template class SoftmaxCUDNNFunctor<platform::float16>;
-template class SoftmaxGradCUDNNFunctor<float>;
-template class SoftmaxGradCUDNNFunctor<platform::float16>;
+template class SoftmaxCUDNNFunctor<float, platform::CUDADeviceContext>;
+template class SoftmaxCUDNNFunctor<platform::float16,
+                                   platform::CUDADeviceContext>;
+template class SoftmaxGradCUDNNFunctor<float, platform::CUDADeviceContext>;
+template class SoftmaxGradCUDNNFunctor<platform::float16,
+                                       platform::CUDADeviceContext>;
+template class SoftmaxCUDNNFunctor<float, phi::GPUContext>;
+template class SoftmaxCUDNNFunctor<platform::float16, phi::GPUContext>;
+template class SoftmaxGradCUDNNFunctor<float, phi::GPUContext>;
+template class SoftmaxGradCUDNNFunctor<platform::float16, phi::GPUContext>;
 #if CUDNN_VERSION_MIN(8, 1, 0)
-template class SoftmaxCUDNNFunctor<platform::bfloat16>;
-template class SoftmaxGradCUDNNFunctor<platform::bfloat16>;
+template class SoftmaxCUDNNFunctor<platform::bfloat16,
+                                   platform::CUDADeviceContext>;
+template class SoftmaxGradCUDNNFunctor<platform::bfloat16,
+                                       platform::CUDADeviceContext>;
+template class SoftmaxCUDNNFunctor<platform::bfloat16, phi::GPUContext>;
+template class SoftmaxGradCUDNNFunctor<platform::bfloat16, phi::GPUContext>;
 #endif

 // MIOPEN do not support double
 #ifndef PADDLE_WITH_HIP
-template class SoftmaxCUDNNFunctor<double>;
-template class SoftmaxGradCUDNNFunctor<double>;
+template class SoftmaxCUDNNFunctor<double, platform::CUDADeviceContext>;
+template class SoftmaxGradCUDNNFunctor<double, platform::CUDADeviceContext>;
+template class SoftmaxCUDNNFunctor<double, phi::GPUContext>;
+template class SoftmaxGradCUDNNFunctor<double, phi::GPUContext>;
 #endif

 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,

paddle/fluid/operators/math/softmax.h

Lines changed: 6 additions & 7 deletions
@@ -36,19 +36,18 @@ class SoftmaxGradFunctor {
 };

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-template <typename T>
+template <typename T, typename DeviceContext>
 class SoftmaxCUDNNFunctor {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor* X, framework::Tensor* Y);
+  void operator()(const DeviceContext& context, const framework::Tensor* X,
+                  framework::Tensor* Y);
 };

-template <typename T>
+template <typename T, typename DeviceContext>
 class SoftmaxGradCUDNNFunctor {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor* Y, const framework::Tensor* y_grad,
-                  framework::Tensor* x_grad);
+  void operator()(const DeviceContext& context, const framework::Tensor* Y,
+                  const framework::Tensor* y_grad, framework::Tensor* x_grad);
 };

 #endif

paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
           phi::make_ddim({1UL, end_pos - start_pos});
       x_i.Resize(dims_i);
       out_i.Resize(dims_i);
-      math::SoftmaxCUDNNFunctor<T>()(
+      math::SoftmaxCUDNNFunctor<T, platform::CUDADeviceContext>()(
           ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
           &out_i);
     }
@@ -93,7 +93,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
       out_i.Resize(dims_i);
       out_grad_i.Resize(dims_i);
       x_grad_i.Resize(dims_i);
-      math::SoftmaxGradCUDNNFunctor<T>()(
+      math::SoftmaxGradCUDNNFunctor<T, platform::CUDADeviceContext>()(
          ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
          &out_grad_i, &x_grad_i);
     }

paddle/fluid/operators/softmax_with_cross_entropy_op.cc

Lines changed: 2 additions & 7 deletions
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"

 namespace paddle {
 namespace operators {
@@ -335,12 +336,6 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
 REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                   ops::SoftmaxWithCrossEntropyOpGrad,
                   ops::SoftmaxWithCrossEntropyGradInplaceInferer);
-REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
-                       ops::SoftmaxWithCrossEntropyKernel<float>,
-                       ops::SoftmaxWithCrossEntropyKernel<double>);
-REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
-                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
-                       ops::SoftmaxWithCrossEntropyGradKernel<double>);

 REGISTER_OP_VERSION(softmax_with_cross_entropy)
 #if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
