10 changes: 9 additions & 1 deletion paddle/fluid/framework/operator.cc
@@ -1122,7 +1122,15 @@ static void CheckTensorNANOrInf(const std::string& op_type,

bool OperatorWithKernel::SupportsMKLDNN(
const proto::VarType::Type data_type) const {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
VLOG(6) << "Warning: " << type_ << " don't find its MKLDNN Kernel in Fluid "
"Registered Kernels. And We don't "
"search its kernels in phi lib, "
"SupportsMKLDNN() return false.";
return false;
}
auto& op_kernels = op_kernel_iter->second;
return std::any_of(op_kernels.begin(), op_kernels.end(),
[data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_cpu_place(kern_pair.first.place_) &&
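Reviewer note: the switch from `map::at()` to `find()` above avoids an uncaught `std::out_of_range` when an op type has no fluid-registered kernels at all. A minimal self-contained sketch of the pattern, with `std::map` standing in for paddle's `OpKernelMap` registry (names here are illustrative, not paddle API):

```cpp
#include <iostream>
#include <map>
#include <string>

// Stand-in registry: op type -> number of registered kernels.
using Registry = std::map<std::string, int>;

bool HasFluidKernels(const Registry& registry, const std::string& op_type) {
  auto it = registry.find(op_type);  // find() reports absence via end()...
  if (it == registry.end()) {
    return false;                    // ...where at() would have thrown.
  }
  return it->second > 0;             // inspect the mapped value as before
}

int main() {
  Registry registry{{"relu", 3}};
  std::cout << HasFluidKernels(registry, "relu") << "\n";  // 1
  std::cout << HasFluidKernels(registry, "log") << "\n";   // 0, no throw
}
```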
12 changes: 3 additions & 9 deletions paddle/fluid/operators/activation_op.cc
@@ -1496,6 +1496,9 @@ REGISTER_ACTIVATION_OP(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
HardSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
LogSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);

/* ========================== sigmoid register =============================
*/
@@ -1867,15 +1870,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(log, Log, LogFunctor, LogGradFunctor);

REGISTER_OP_CPU_KERNEL(
log_grad_grad, ops::LogDoubleGradKernel<plat::CPUDeviceContext,
ops::LogGradGradFunctor<float>>,
ops::LogDoubleGradKernel<plat::CPUDeviceContext,
ops::LogGradGradFunctor<double>>,
ops::LogDoubleGradKernel<plat::CPUDeviceContext,
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
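Reviewer note: the `log` and `log_grad_grad` CPU kernel registrations are deleted here because this PR moves those kernels into the phi library. For orientation only, a hedged sketch of the usual phi-side registration shape; the target file and exact dtype list are assumptions, not visible in this diff:

```cpp
// Assumed shape of the phi-side registration (cf. paddle/phi/kernels/cpu/);
// the precise dtype list for log in this PR is not shown in this diff.
PD_REGISTER_KERNEL(log,                // kernel name
                   CPU,                // backend
                   ALL_LAYOUT,         // data layout
                   phi::LogKernel,     // kernel function
                   float,
                   double,
                   phi::dtype::float16) {}
```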

/* ========================== register checkpoint ===========================*/
151 changes: 5 additions & 146 deletions paddle/fluid/operators/activation_op.h
@@ -281,6 +281,11 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_FUNCTOR(LogSigmoid)
USE_PHI_FUNCTOR(HardSigmoid)
USE_PHI_FUNCTOR(Log)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR(Log2)
USE_PHI_FUNCTOR(Log10)
USE_PHI_FUNCTOR(Log1p)

template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
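Reviewer note: `USE_PHI_FUNCTOR` presumably redirects the fluid functor names to their phi implementations, in the same way the explicit `using` alias just above redirects `ELUGradNegativeAlphaFunctor`. A hedged guess at the expansion; the macro body itself is not part of this diff:

```cpp
// Assumed expansion, by analogy with the explicit alias above:
#define USE_PHI_FUNCTOR(name)                                \
  template <typename T>                                      \
  using name##Functor = phi::funcs::name##Functor<T>;        \
  template <typename T>                                      \
  using name##GradFunctor = phi::funcs::name##GradFunctor<T>;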
@@ -448,88 +453,6 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log();
}
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) / x);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(2));
}
};

// the gradient of log2(x) is 1/(x*ln(2))
template <typename T>
struct Log2GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(10));
}
};

// the gradient of log10(x) is 1/(x*ln(10))
template <typename T>
struct Log10GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = (static_cast<T>(1) + x).log();
}
};

template <typename T>
struct Log1pGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) / (x + static_cast<T>(1)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
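Reviewer note: for reference, the analytic gradients these removed functors implemented. Each depends only on the forward input x, which is why they all report `FwdDeps() == kDepX`:

```latex
\frac{d}{dx}\log x = \frac{1}{x},\qquad
\frac{d}{dx}\log_2 x = \frac{1}{x\ln 2},\qquad
\frac{d}{dx}\log_{10} x = \frac{1}{x\ln 10},\qquad
\frac{d}{dx}\log(1+x) = \frac{1}{1+x}.
```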

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
@@ -1197,37 +1120,6 @@ class SquareDoubleGradKernel
}
};

template <typename DeviceContext, typename Functor>
class LogDoubleGradKernel
: public SquareDoubleGradKernel<DeviceContext, Functor> {};

template <typename DeviceContext, typename Functor>
class ELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;

ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);

if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());

auto& place = ctx.template device_context<DeviceContext>();

Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, ddX, ddOut, dOut, dX);
}
};

template <typename DeviceContext, typename Functor>
class CELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1522,36 +1414,6 @@ class LogitGradKernel : public framework::OpKernel<T> {
}
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
// ddout = ddx / x; dx = -(dout / x) * (ddx / x)
// calculate dx first, so ddout can inplace ddx
if (dX) {
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
dx.device(*d) = dout * static_cast<T>(-1) * ddx / (x * x);
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(1) / x;
}
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
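Reviewer note: the removed double-grad functor encodes the second-order rule spelled out in its comment. With the first-order backward dx = dout / x, differentiating once more gives:

```latex
\text{ddout} = \text{ddx}\cdot\frac{\partial\,\text{dx}}{\partial\,\text{dout}}
             = \frac{\text{ddx}}{x},
\qquad
\text{dX} = \text{ddx}\cdot\frac{\partial\,\text{dx}}{\partial x}
          = -\,\frac{\text{dout}\cdot\text{ddx}}{x^{2}},
```

which matches `ddout = ddx / x` and `dx = -(dout / x) * (ddx / x)` in the code above; dX is computed first so that ddout can safely reuse ddx's buffer in place.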

} // namespace operators
} // namespace paddle

@@ -1560,9 +1422,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
__macro(log2, Log2, Log2Functor, Log2GradFunctor); \
__macro(log10, Log10, Log10Functor, Log10GradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
112 changes: 4 additions & 108 deletions paddle/fluid/operators/activation_op.kps
@@ -131,27 +131,6 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;

// log(x) = natural logarithm of x
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(x));
}
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
// dx = dout / x
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / x;
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
// square(x) = x * x
@@ -220,78 +199,6 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);

// log1p(x) = log(1 + x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(one + x));
}
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// dx = dout / (1 + x)
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (one + x);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;

// log2(x) = base-2 logarithm of x
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log2(x));
}
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

// dx = dout / (x * log(2))
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (x * log_two);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;

// log10(x) = base-10 logarithm of x
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log10(x));
}
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

// dx = dout / (x * log(10))
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (x * log_ten);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
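Reviewer note: all of the removed CUDA functors share the MPType ("math precision type") idiom: half-precision inputs are widened before calling the transcendental and narrowed back afterwards, so float16 still gets full-precision math. A minimal host-side sketch of the idiom; `MPTypeTrait` here is a simplified stand-in for paddle's `details::MPTypeTrait`:

```cpp
#include <cmath>

// Simplified stand-in for details::MPTypeTrait: identity by default. A real
// float16 type would specialize to float, e.g.
//   template <> struct MPTypeTrait<platform::float16> { using Type = float; };
template <typename T>
struct MPTypeTrait {
  using Type = T;
};

// Compute log in the wider MPType, then cast back to T -- the same shape as
// the removed CudaLogFunctor's operator().
template <typename T>
T LogViaMPType(T arg_x) {
  using MPType = typename MPTypeTrait<T>::Type;
  MPType x = static_cast<MPType>(arg_x);
  return static_cast<T>(std::log(x));
}
```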

template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
@@ -773,6 +680,10 @@ USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)

template <typename T>
using CudaELUGradNegativeAlphaFunctor =
@@ -975,18 +886,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);

REGISTER_OP_CUDA_KERNEL(
log_grad_grad, ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<float>>,
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<double>>,
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */

#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \
@@ -995,9 +894,6 @@ REGISTER_OP_CUDA_KERNEL(
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
__macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor); \
__macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor); \
__macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \
__macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \