15 changes: 11 additions & 4 deletions paddle/fluid/operators/kernel_primitives/functor_primitives.h
@@ -14,7 +14,10 @@

#pragma once

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
@@ -74,16 +77,20 @@ struct IdentityFunctor
*/
template <typename Tx, typename Ty = Tx>
struct DivideFunctor {
HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); }
private:
using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type;

public:
HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<MPType>(1.0f); }

HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {}
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {}

HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x * n_inv);
return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
}

private:
Tx n_inv;
MPType n_inv;
};

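Note: the point of the MPType change above is that, when Tx is FP16, the stored reciprocal and the multiply happen in float rather than in float16 (details::MPTypeTrait maps float16 to float). Below is a minimal, self-contained, host-only sketch of that idea; half_t and MPTypeTrait here are simplified stand-ins for paddle::platform::float16 and the real trait, not the Paddle types.

#include <cstdio>

// Stand-in for paddle::platform::float16 (assumption: the real type converts
// to/from float; here we simply wrap a float to keep the sketch host-only).
struct half_t {
  half_t() = default;
  half_t(float f) : v(f) {}
  operator float() const { return v; }
  float v = 0.f;
};

// Stand-in for details::MPTypeTrait: FP16 computes in float, other types in place.
template <typename T> struct MPTypeTrait { using Type = T; };
template <> struct MPTypeTrait<half_t> { using Type = float; };

template <typename Tx, typename Ty = Tx>
struct DivideFunctorSketch {
  using MPType = typename MPTypeTrait<Tx>::Type;
  explicit DivideFunctorSketch(int n) : n_inv(static_cast<MPType>(1.0 / n)) {}
  Ty operator()(const Tx& x) const {
    // Upcast the input, multiply by 1/n in MPType, cast back only at the end.
    return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
  }
  MPType n_inv;  // stored in MPType (float for FP16 inputs), not in Tx
};

int main() {
  DivideFunctorSketch<half_t, half_t> div(3);
  std::printf("%f\n", static_cast<float>(div(half_t(1.0f))));  // ~0.3333
  return 0;
}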
53 changes: 24 additions & 29 deletions paddle/fluid/operators/mean_op.cu
@@ -18,30 +18,23 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n)
: n_inv(static_cast<T>(1.0 / n)) {}

HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

private:
T n_inv;
};

template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
@AnnaTrainingG (Contributor), Dec 21, 2021:
Currently both mean and reduce_mean call the reduce implementation inside PTen, which chen tian yu is modifying. Please confirm whether this change actually takes effect.

Collaborator (PR author):
Confirmed by adding logs: this code path is indeed reached, because PTen has no FP16 registration. A follow-up PR will add that registration in PTen.


using MT = typename details::MPTypeTrait<T>::Type;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
T data = in_data[0];
auto data = static_cast<MT>(in_data[0]);
for (; idx < N; idx += blockDim.x * gridDim.x) {
out_data[idx] = data / (static_cast<T>(N));
out_data[idx] = static_cast<T>(data / (static_cast<MT>(N)));
}
}

@@ -52,27 +45,29 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");

output->mutable_data<T>(context.GetPlace());
auto size_prob = input->numel();
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(context.GetPlace());
auto numel = input->numel();
auto rank = input->dims().size();
auto place = context.GetPlace();
auto stream = context.cuda_device_context().stream();

DivideFunctor<T> transformer(size_prob);
cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
in_data, transformer);
size_t temp_storage_bytes = 0;
if (rank == 0) { // scalar
auto gpu_place = BOOST_GET(platform::CUDAPlace, place);
memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T),
stream);
return;
}

auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
out_data, size_prob, stream);
PADDLE_ENFORCE_GPU_SUCCESS(err);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
context.GetPlace());
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
out_data, size_prob, stream);
PADDLE_ENFORCE_GPU_SUCCESS(err);
using MT = typename details::MPTypeTrait<T>::Type;
using Div = kernel_primitives::DivideFunctor<T, MT>;
std::vector<int> reduce_dims;
reduce_dims.reserve(rank);
for (decltype(rank) i = 0; i < rank; ++i) {
reduce_dims.push_back(i);
}
TensorReduceFunctorImpl<T, T, kernel_primitives::AddFunctor, Div>(
*input, output, Div(numel), reduce_dims, stream);
}
};

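Note: the rewritten MeanCUDAKernel above delegates the forward mean to TensorReduceFunctorImpl with kps::AddFunctor plus the shared kernel_primitives::DivideFunctor, while MeanRunKernel keeps its grid-stride loop but now divides in the higher-precision MT. The snippet below is a self-contained CUDA sketch of just that MeanRunKernel pattern (assuming a recent CUDA toolkit and cuda_fp16.h); it is illustrative, not the Paddle kernel itself.

#include <cstdio>
#include <cuda_fp16.h>

// Read the scalar once in __half, upcast to float (the MT of the sketch),
// divide in float, and cast back to __half only when writing out.
__global__ void MeanRunKernelSketch(const __half* in, __half* out, int n) {
  float data = __half2float(in[0]);
  for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < n;
       idx += blockDim.x * gridDim.x) {
    out[idx] = __float2half(data / static_cast<float>(n));
  }
}

int main() {
  const int n = 8;
  __half *in, *out;
  cudaMallocManaged(&in, sizeof(__half));
  cudaMallocManaged(&out, n * sizeof(__half));
  *in = __float2half(1.0f);
  MeanRunKernelSketch<<<1, 32>>>(in, out, n);
  cudaDeviceSynchronize();
  std::printf("%f\n", __half2float(out[0]));  // 0.125
  cudaFree(in);
  cudaFree(out);
  return 0;
}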
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -77,7 +77,7 @@ struct CustomSub {

template <typename Tx, typename Ty = Tx>
struct CustomMean {
using Transformer = kps::DivideFunctor<Tx>;
using Transformer = kps::DivideFunctor<Tx, Ty>;
Contributor:
The latest reduce_op.cu.h does not use CustomMean; it was originally planned for deletion, but PTen uses it. Changing only this place will not modify the reduce implementation inside PTen, though they seem to be moving that code.

Collaborator (PR author):
Confirmed by adding logs: this code path is indeed reached, because PTen has no FP16 registration. A follow-up PR will add that registration in PTen.

inline Ty initial() { return static_cast<Ty>(0.0f); }

2 changes: 2 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
@@ -19,5 +19,7 @@
REGISTER_OP_CUDA_KERNEL(
reduce_mean,
ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
kps::DivideFunctor>,
Contributor:
chentianyu says that PTen has internally added registrations for mean and sum; please confirm whether this change really causes the FP16 path to be called.

Collaborator (PR author):
Confirmed by adding logs: this code path is indeed reached, because PTen has no FP16 registration. A follow-up PR will add that registration in PTen.

ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>);
13 changes: 13 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.h
@@ -35,5 +35,18 @@ struct MeanGradFunctor {
}
};

// TODO(zengjinle): Should refine the numeric stability of FP16 reduce_mean
// and reduce_mean_grad later.
struct FP16MeanGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = (dy->template cast<float>().broadcast(dim) /
dx->template cast<float>().constant(size))
.template cast<platform::float16>();
}
};

} // namespace operators
} // namespace paddle
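Note: FP16MeanGradFunctor above broadcasts dy, divides by the element count in float, and casts back to float16 only for the final store. Below is a rough host-side Eigen sketch of that same expression pattern (assuming Eigen's unsupported Tensor module is available; the shapes are made up for illustration and this is not the Paddle code path).

#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const int n = 4;                      // pretend the reduced axis had 4 elements
  Eigen::Tensor<Eigen::half, 1> dy(1);  // upstream gradient of the mean (a scalar)
  Eigen::Tensor<Eigen::half, 1> dx(n);  // gradient w.r.t. the input
  dy.setConstant(Eigen::half(1.0f));

  Eigen::array<int, 1> bcast = {{n}};
  // Same pattern as FP16MeanGradFunctor: upcast to float, broadcast, divide by
  // the element count, then cast back to half only at the end.
  dx = (dy.cast<float>().broadcast(bcast) /
        dx.cast<float>().constant(static_cast<float>(n)))
           .cast<Eigen::half>();

  std::printf("dx[0] = %f\n", static_cast<float>(dx(0)));  // 0.25
  return 0;
}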
6 changes: 6 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
@@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::MeanGradFunctor, true>;

using FP16CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16, ops::FP16MeanGradFunctor,
true>;

REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
FP16CUDAReduceMeanGradKernel,
CUDAReduceMeanGradKernel<float>,
CUDAReduceMeanGradKernel<double>);
57 changes: 38 additions & 19 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -38,7 +38,9 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/string/string_helper.h"

// Reduce split or not, Whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
}
}

template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp>
static typename std::enable_if<!std::is_same<Tx, platform::float16>::value,
void>::type
CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
const TransformOp& transform, int reduce_num,
const platform::Place& place, gpuStream_t stream) {
auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, reducer.initial(), stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, reducer.initial(), stream);
}

template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp>
static typename std::enable_if<std::is_same<Tx, platform::float16>::value,
void>::type
CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
const TransformOp& transform, int reduce_num,
const platform::Place& place, gpuStream_t stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}

template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
const TransformOp& transform,
std::vector<int> origin_reduce_dims,
const std::vector<int>& origin_reduce_dims,
gpuStream_t stream) {
auto x_dim = framework::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
}

config.SetOutputData(y_data, x.place(), &tmp);
bool use_cub_reduce = (config.reduce_num == numel) &&
(!std::is_same<Tx, paddle::platform::float16>::value);
constexpr bool kIsTxFP16 = std::is_same<Tx, paddle::platform::float16>::value;
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
if (use_cub_reduce) {
// launch CUB::Reduce
auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
x.place());
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
stream);

CubTensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp>(
x_data, y_data, transform, config.reduce_num, x.place(), stream);
return;
}

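Note: the two CubTensorReduceFunctorImpl overloads above use std::enable_if so the cub::DeviceReduce path is only instantiated when Tx is not float16, and the float16 overload simply raises an error if it is ever reached. Below is a standalone sketch of that compile-time dispatch pattern; the float16 struct here is a dummy stand-in, not the Paddle type.

#include <cstdio>
#include <type_traits>

struct float16 {};  // stand-in for platform::float16

// Selected whenever Tx is NOT the FP16 type: the "cub" path.
template <typename Tx>
typename std::enable_if<!std::is_same<Tx, float16>::value, void>::type
CubReduceSketch() {
  std::printf("cub::DeviceReduce path\n");
}

// Selected only when Tx IS the FP16 type: the guard path.
template <typename Tx>
typename std::enable_if<std::is_same<Tx, float16>::value, void>::type
CubReduceSketch() {
  std::printf("FP16: rejected, the generic reduce path is used instead\n");
}

int main() {
  CubReduceSketch<float>();    // resolves to the cub overload
  CubReduceSketch<float16>();  // resolves to the FP16 guard overload
  return 0;
}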
8 changes: 5 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
std::vector<int> reduce_dims =
GetReduceDim(dims, input->dims().size(), reduce_all);
int reduce_num = 1;
for (int i = 0; i < input->dims().size(); i++) {
for (auto i : reduce_dims) {
reduce_num *= (input->dims())[i];
}
gpuStream_t stream = context.cuda_device_context().stream();
@@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
TensorReduceFunc<T, ReduceOp, TransformOp>(
*input, output, reduce_dims, reduce_num, stream));
} else {
TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, T>>(
*input, output, TransformOp<T, T>(reduce_num), reduce_dims, stream);
using MPType = typename details::MPTypeTrait<T>::Type;
TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
*input, output, TransformOp<T, MPType>(reduce_num), reduce_dims,
stream);
}
}
};
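Note: two things change in ReduceCudaKernel above: reduce_num is now accumulated only over the dimensions actually being reduced, and the transform is instantiated as TransformOp<T, MPType> so FP16 inputs are transformed in float before accumulation. A tiny sketch of the reduce_num change, with made-up shapes:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims = {2, 3, 4};      // input shape (illustrative)
  std::vector<int> reduce_dims = {1, 2};  // reduce over the last two dims
  int reduce_num = 1;
  // Multiply only the reduced dimensions, not every dimension of the input.
  for (int i : reduce_dims) reduce_num *= dims[i];
  std::printf("reduce_num = %d\n", reduce_num);  // 12, not 24
  return 0;
}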