Skip to content
Merged
6 changes: 6 additions & 0 deletions paddle/phi/common/amp_type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/float8_e4m3fn.h"
#include "paddle/phi/common/float8_e5m2.h"
Expand Down Expand Up @@ -52,5 +53,10 @@ class MPTypeTrait<phi::dtype::float8_e5m2> {
using Type = float;
};

// Mixed-precision trait for half-precision complex values: the alias is
// complex<float>.
// The alias must be spelled `Type` (capital T): that is the member name used
// by the other MPTypeTrait specializations and by every
// `MPTypeTrait<T>::Type` lookup in the activation functors.
template <>
struct MPTypeTrait<phi::dtype::complex<float16>> {
  using Type = phi::dtype::complex<float>;
};

} // namespace dtype
} // namespace phi
237 changes: 203 additions & 34 deletions paddle/phi/kernels/funcs/activation_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3606,6 +3606,15 @@ struct CudaSquareGradFunctor<ComplexType<T>>
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Element-wise reciprocal of the square.
template <typename T>
struct CudaRsquareFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // rsquare(x) = 1 / (x * x)
  __device__ __forceinline__ T operator()(const T x) const {
    const T squared = x * x;
    return one / squared;
  }
};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
Expand Down Expand Up @@ -3705,6 +3714,36 @@ struct CudaReciprocalGradFunctor<ComplexType<T>>
}
};

// Gradient functor for pow(x, -1), expressed in terms of the input x
// (FwdDeps() reports kDepX). The arithmetic is carried out in MPType and the
// result is cast back to T.
template <typename T>
struct CudaReciprocalGradDepXFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = -dout * (1 / (x * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType inv_square = one / (x * x);
    return static_cast<T>(-dout * inv_square);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Complex specialization of the pow(x, -1) gradient: the derivative factor
// 1 / (x * x) is conjugated before it is multiplied by -dout.
template <typename T>
struct CudaReciprocalGradDepXFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);

  // dx = -dout * conj(1 / (x * x))
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> x) const {
    const ComplexType<T> inv_square = one / (x * x);
    return -dout * conj(inv_square);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
Expand Down Expand Up @@ -4296,6 +4335,36 @@ struct CudaSqrtGradFunctor<ComplexType<T>>
}
};

// Gradient functor for pow(x, 0.5), expressed in terms of the input x
// (FwdDeps() reports kDepX). The scale 0.5 * rsqrt(x) is evaluated in MPType,
// cast to T, and then multiplied by dout in T.
template <typename T>
struct CudaSqrtGradDepXFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  MPType one_half = static_cast<MPType>(0.5f);

  // dx = dout * (0.5 * rsqrt(x))
  __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const T scale = static_cast<T>(one_half * rsqrt(x));
    return dout * scale;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Complex specialization of the pow(x, 0.5) gradient: the derivative factor
// 0.5 / sqrt(x) is conjugated before it is multiplied by dout.
template <typename T>
struct CudaSqrtGradDepXFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  ComplexType<T> one_half = static_cast<ComplexType<T>>(0.5f);

  // dx = dout * conj(0.5 / sqrt(x))
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> x) const {
    const ComplexType<T> derivative = one_half / sqrt(x);
    return dout * conj(derivative);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
Expand All @@ -4307,6 +4376,18 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
}
};

// Complex specialization of the reciprocal square root, computed as a
// complex division of one by sqrt(x).
template <typename T>
struct CudaRsqrtFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);

  // rsqrt(x) = 1 / sqrt(x)
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> arg_x) const {
    const ComplexType<T> root = sqrt(arg_x);
    return one / root;
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
Expand Down Expand Up @@ -5407,81 +5488,169 @@ struct CudaCeilFunctor : public BaseActivationFunctor<T> {
}
};

// Integral overload of compute_pow (enabled only when T is an integral type).
// The base is widened to double, raised to the double exponent `b`, and the
// result is rounded to the nearest integer and returned as int64_t.
// T is deduced from the base, so callers do not spell template arguments.
template <typename T>
__device__ __forceinline__
    typename std::enable_if<std::is_integral<T>::value, int64_t>::type
    compute_pow(const T a, const double b) {
  // TODO(wujionghao): A potential speed improvement is supporting different
  // types in C++.
  // On CUDAPlace, pow(3, 1) calls pow(float, float), and
  // it will return a float number like 2.99... , which floor to 2
  // when cast to int by default and it is wrong.
  // Use llrint to cast it to the nearest integer, which is 3.
  return llrint(pow(static_cast<double>(a), b));
}

// Non-integral overload of compute_pow (enabled only when T is not an
// integral type). The base is cast to MPType, raised to the MPType exponent
// `b`, and the result is returned in MPType; any narrowing back to T is left
// to the caller.
template <typename T, typename MPType>
__device__ __forceinline__
    typename std::enable_if<!std::is_integral<T>::value, MPType>::type
    compute_pow(const T a, const MPType b) {
  return pow(static_cast<MPType>(a), b);
}

// Complex overload of compute_pow: the base is converted to
// ComplexType<MPType> and raised to the complex exponent `b`; the result is
// returned as ComplexType<MPType>.
template <typename T, typename MPType>
__device__ __forceinline__ typename std::enable_if<!std::is_integral<T>::value,
                                                   ComplexType<MPType>>::type
compute_pow(const ComplexType<T> a, const ComplexType<MPType> b) {
  const ComplexType<MPType> base = static_cast<ComplexType<MPType>>(a);
  return pow(base, b);
}

// Shared state for the forward pow functors: holds the exponent `factor`,
// stored in MPType, and exposes it through GetAttrs() and SetFactor().
// Derived functors supply operator().
template <typename T>
struct BaseCudaPowFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  // Stores the exponent, converting the double argument to MPType.
  void SetFactor(double factor) { this->factor = static_cast<MPType>(factor); }
};

// Shared state for the backward pow functors: holds the exponent `factor`,
// stored in MPType, exposes it through GetAttrs() and SetFactor(), and
// reports kDepX from FwdDeps(). Derived functors supply operator().
template <typename T>
struct BaseCudaPowGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  // Stores the exponent, converting the double argument to MPType.
  void SetFactor(double factor) { this->factor = static_cast<MPType>(factor); }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward pow functor: raises x to the exponent stored in the base class and
// casts the compute_pow result back to T.
template <typename T>
struct CudaPowFunctor : public BaseCudaPowFunctor<T> {
  // out = pow(x, factor)
  __device__ __forceinline__ T operator()(const T x) const {
    const auto raised = compute_pow(x, this->factor);
    return static_cast<T>(raised);
  }
};

// Backward pow functor. `factor` lives in the dependent base class, so it is
// reached through `this->`. The product factor * pow(x, factor - 1) is formed
// in the type returned by compute_pow, cast to T, and then multiplied by dout.
template <typename T>
struct CudaPowGradFunctor : public BaseCudaPowGradFunctor<T> {
  // dx = dout * n * pow(x, n - 1)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout *
           static_cast<T>(this->factor * compute_pow(x, this->factor - 1));
  }
};

// Complex specialization of the backward pow functor: the derivative factor
// n * pow(x, n - 1) is conjugated, cast to ComplexType<T>, and multiplied by
// dout.
template <typename T>
struct CudaPowGradFunctor<ComplexType<T>>
    : public BaseCudaPowGradFunctor<ComplexType<T>> {
  using MPType = typename phi::dtype::MPTypeTrait<ComplexType<T>>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * conj(n * pow(x, n - 1))
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> x) const {
    const auto derivative =
        this->factor * compute_pow(x, this->factor - one);
    return dout * static_cast<ComplexType<T>>(conj(derivative));
  }
};

// Element-wise cube computed with two multiplications.
template <typename T>
struct CudaCubeFunctor : public BaseActivationFunctor<T> {
  // cube(x) = (x * x) * x
  __device__ __forceinline__ T operator()(const T x) const {
    const T squared = x * x;
    return squared * x;
  }
};

// Gradient functor for cube(x), expressed in terms of the input x
// (FwdDeps() reports kDepX).
template <typename T>
struct CudaCubeGradFunctor : public BaseActivationFunctor<T> {
  T three = static_cast<T>(3.0f);

  // dx = dout * (3 * (x * x))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T derivative = three * (x * x);
    return dout * derivative;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaPowFunctor<ComplexType<T>>
struct CudaCubeGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
ComplexType<T> three = static_cast<ComplexType<T>>(3.0f);

// dx = dout * conj(3 * x * x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> x) const {
return pow(x, static_cast<ComplexType<T>>(factor));
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(dout * conj(three * (x * x)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaPowGradFunctor<ComplexType<T>>
struct CudaPow4GradFunctor : public BaseActivationFunctor<T> {
T four = static_cast<T>(4.0f);

// dx = dout * 4 * x * x * x
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout * (four * (x * x * x));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Complex specialization of the fourth-power gradient: the derivative factor
// 4 * (x * x * x) is conjugated before it is multiplied by dout.
// FwdDeps() reports kDepX.
template <typename T>
struct CudaPow4GradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  ComplexType<T> four = static_cast<ComplexType<T>>(4.0f);

  // dx = dout * conj(4 * x * x * x)
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> x) const {
    return static_cast<ComplexType<T>>(dout * conj(four * (x * x * x)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Gradient functor for pow(x, 1.5), expressed in terms of the input x
// (FwdDeps() reports kDepX). The scale 1.5 * sqrt(x) is evaluated in MPType,
// cast to T, and then multiplied by dout in T.
template <typename T>
struct CudaPow1p5GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Built directly in MPType, like `one_half` in CudaSqrtGradDepXFunctor, so
  // the constant does not take a detour through the storage type T.
  MPType f1p5 = static_cast<MPType>(1.5f);

  // dx = dout * 1.5 * sqrt(x)
  __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return dout * static_cast<T>(f1p5 * sqrt(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Complex specialization of the pow(x, 1.5) gradient: the derivative factor
// 1.5 * sqrt(x) is conjugated before it is multiplied by dout.
// This struct has no `factor` member; the exponent is fixed by the `f1p5`
// constant. FwdDeps() reports kDepX.
template <typename T>
struct CudaPow1p5GradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  ComplexType<T> f1p5 = static_cast<ComplexType<T>>(1.5f);

  // dx = dout * conj(1.5 * sqrt(x))
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> x) const {
    return static_cast<ComplexType<T>>(dout * conj(f1p5 * sqrt(x)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/elementwise_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ struct ElementwiseInversePowFunctor<ComplexType<T>> {
inline HOSTDEVICE ComplexType<T> operator()(const ComplexType<T> a,
const ComplexType<T> b) const {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
return pow(a, b);
return pow(b, a);
#else
return std::pow(static_cast<std::complex<T>>(b),
static_cast<std::complex<T>>(a));
Expand Down
Loading