[PHI decouple] move dropout_impl and cuda_graph_with_memory_pool from fluid to phi #49139

Merged
merged 12 commits on Dec 20, 2022
move cuda_graph_with_memory_pool from fluid to phi
huangjiyi committed Dec 17, 2022
commit bfbfcd8601b0323e093be6ccd80c666153e36810
7 changes: 4 additions & 3 deletions paddle/fluid/platform/cuda_graph_with_memory_pool.cc
@@ -15,18 +15,19 @@
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

DECLARE_bool(use_stream_safe_cuda_allocator);

namespace paddle {
namespace platform {

#ifdef PADDLE_WITH_CUDA
-void BeginCUDAGraphCapture(platform::CUDAPlace place,
+void BeginCUDAGraphCapture(phi::GPUPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id) {
-  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* mutable_dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();

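For orientation, here is a minimal sketch (not part of the diff) of how the capture API above is typically driven. The wrapper function is hypothetical, and `CUDAGraph::Replay()` is assumed to keep its existing phi signature:

```cpp
// Hypothetical driver for the relocated API, assuming a PADDLE_WITH_CUDA build.
#include <memory>

#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

void CaptureAndReplayOnce() {
  phi::GPUPlace place(0);  // note: the signature now takes phi::GPUPlace
  // Kernels launched between Begin/End are recorded into a graph rather
  // than executed immediately.
  paddle::platform::BeginCUDAGraphCapture(place,
                                          cudaStreamCaptureModeThreadLocal);
  // ... enqueue GPU work on the capturing stream ...
  std::unique_ptr<paddle::platform::CUDAGraph> graph =
      paddle::platform::EndCUDAGraphCapture();
  graph->Replay();  // assumed: replays the recorded kernels
}
```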
107 changes: 11 additions & 96 deletions paddle/fluid/platform/cuda_graph_with_memory_pool.h
@@ -14,123 +14,38 @@

#pragma once

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#endif

namespace paddle {
namespace platform {

-#ifdef PADDLE_WITH_CUDA
-#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                            \
-                                           __kernel_func,                     \
-                                           __grid,                            \
-                                           __block,                           \
-                                           __sm_size,                         \
-                                           __stream,                          \
-                                           __seed_inc,                        \
-                                           __seed_expr,                       \
-                                           __offset_expr,                     \
-                                           ...)                               \
-  do {                                                                        \
-    if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
-      using __Helper =                                                        \
-          ::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func),    \
-                                                 &__kernel_func>;             \
-      auto *dev_ctx =                                                         \
-          ::paddle::platform::DeviceContextPool::Instance().GetByPlace(       \
-              ::paddle::platform::CUDAGraph::CapturingPlace());               \
-      auto __set_seed_func =                                                  \
-          [=](::paddle::platform::CUDAKernelParams *__params,                 \
-              bool __check_only) -> bool {                                    \
-        if (__check_only) {                                                   \
-          return __params->func() == &__kernel_func &&                        \
-                 __Helper::Compare(*__params, __VA_ARGS__);                   \
-        }                                                                     \
-        auto &KERNEL_PARAMS = *__params;                                      \
-        uint64_t __seed, __offset;                                            \
-        ::paddle::operators::GetSeedDataAndIncrement(                         \
-            *dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset);     \
-        __seed_expr = static_cast<decltype(__seed_expr)>(__seed);             \
-        __offset_expr = static_cast<decltype(__offset_expr)>(__offset);       \
-        return true;                                                          \
-      };                                                                      \
-      ::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
-    }                                                                         \
-    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);     \
-  } while (0)
-#else
-#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                        \
-                                           __kernel_func,                 \
-                                           __grid,                        \
-                                           __block,                       \
-                                           __sm_size,                     \
-                                           __stream,                      \
-                                           __seed_inc,                    \
-                                           __seed_expr,                   \
-                                           __offset_expr,                 \
-                                           ...)                           \
-  do {                                                                    \
-    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
-  } while (0)
-#endif
+using CUDAGraph = phi::backends::gpu::CUDAGraph;

// NOTE: These APIs are not thread-safe.
#ifdef PADDLE_WITH_CUDA
-void BeginCUDAGraphCapture(platform::CUDAPlace place,
+void BeginCUDAGraphCapture(phi::GPUPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id = CUDAGraph::kInvalidPoolID);
int64_t pool_id = CUDAGraph::kInvalidPoolID);
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
#endif

-inline bool IsCUDAGraphCapturing() {
-#ifdef PADDLE_WITH_CUDA
-  return CUDAGraph::IsCapturing();
-#else
-  return false;
-#endif
-}

-inline platform::CUDAPlace CUDAGraphCapturingPlace() {
+inline phi::GPUPlace CUDAGraphCapturingPlace() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::CapturingPlace();
#else
-  PADDLE_THROW(platform::errors::Unimplemented(
+  PADDLE_THROW(phi::errors::Unimplemented(
      "CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}

-// Add reset callback if CUDA Graph is capturing.
-// Otherwise, invoke callback directly.
-template <typename Callback>
-inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
-#ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(IsCUDAGraphCapturing())) {
-    return CUDAGraph::AddResetCallbackDuringCapturing(
-        std::forward<Callback>(callback));
-  }
-#endif
-  callback();
-}
-
-template <typename T>
-inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
-  static_assert(std::is_trivial<T>::value, "T must be trivial type");
-  static_assert(!std::is_same<T, void>::value, "T cannot be void");
-#ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(IsCUDAGraphCapturing())) {
-    size_t nbytes = size * sizeof(T);
-    void *new_host_mem = new uint8_t[nbytes];
-    std::memcpy(new_host_mem, host_mem, nbytes);
-    AddResetCallbackIfCapturingCUDAGraph(
-        [new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
-    return reinterpret_cast<T *>(new_host_mem);
-  }
-#endif
-  return host_mem;
-}
+using phi::backends::gpu::AddResetCallbackIfCapturingCUDAGraph;
+using phi::backends::gpu::IsCUDAGraphCapturing;
+using phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph;

class SkipCUDAGraphCaptureGuard {
DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
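Net effect of this file: the fluid header becomes a thin shim over phi. A small sketch, assuming existing call sites are left untouched, of code that still compiles because the using-declarations forward to the phi definitions:

```cpp
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

void LegacyFluidCallSite() {
  // Both names resolve to phi::backends::gpu via the using-declarations above.
  if (paddle::platform::IsCUDAGraphCapturing()) {
    // Defer cleanup until the captured graph is reset instead of running now.
    paddle::platform::AddResetCallbackIfCapturingCUDAGraph(
        [] { /* release per-capture state here */ });
  }
}
```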
49 changes: 1 addition & 48 deletions paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
@@ -25,54 +25,7 @@ using cudaStreamCaptureMode = phi::backends::gpu::cudaStreamCaptureMode;
#endif
using CUDAGraph = phi::backends::gpu::CUDAGraph;
using CUDAGraphCaptureModeGuard = phi::backends::gpu::CUDAGraphCaptureModeGuard;

-template <typename T>
-static bool IsBitwiseEqual(const T &x, const T &y) {
-  return std::memcmp(&x, &y, sizeof(T)) == 0;
-}
-
-template <typename F, F f>
-struct IsSameKernelHelper;
-
-template <typename Return,
-          typename... FuncArgs,
-          Return (*kernel_fn)(FuncArgs...)>
-struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
- private:
-  using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
-
-  template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
-  struct Impl {
-    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
-      using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
-      if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
-                                    std::get<IDX>(args))) {
-        return false;
-      }
-
-      constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
-      return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
-    }
-  };
-
-  template <typename TupleT, size_t IDX>
-  struct Impl<TupleT, IDX, true> {
-    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
-      return true;
-    }
-  };
-
- public:
-  template <typename... Args>
-  static bool Compare(const CUDAKernelParams &params, Args... args) {
-    constexpr auto kNumArgs = sizeof...(FuncArgs);
-    static_assert(kNumArgs == sizeof...(Args), "Argument number not match");
-
-    auto args_tuple = std::make_tuple(args...);
-    using TupleT = typename std::decay<decltype(args_tuple)>::type;
-    return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
-  }
-};
+using IsSameKernelHelper = phi::backends::gpu::IsSameKernelHelper;

} // namespace platform
} // namespace paddle
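Likewise, fluid code that names these types keeps compiling through the aliases. A sketch, assuming `CUDAGraphCaptureModeGuard` retains its phi constructor taking a `cudaStreamCaptureMode`:

```cpp
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"

void TemporarilyRelaxCapture() {
  // While this guard is alive, the thread's capture mode is relaxed so calls
  // that are illegal under strict capture do not abort an ongoing capture;
  // the destructor restores the previous mode (constructor assumed unchanged).
  paddle::platform::CUDAGraphCaptureModeGuard guard(
      cudaStreamCaptureModeRelaxed);
  // ... e.g. allocations or host-side work not meant to be captured ...
}
```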
48 changes: 48 additions & 0 deletions paddle/phi/backends/gpu/cuda/cuda_graph.h
@@ -236,6 +236,54 @@ class CUDAGraphCaptureModeGuard {
};
#endif

template <typename T>
static bool IsBitwiseEqual(const T &x, const T &y) {
return std::memcmp(&x, &y, sizeof(T)) == 0;
}

template <typename F, F f>
struct IsSameKernelHelper;

template <typename Return,
typename... FuncArgs,
Return (*kernel_fn)(FuncArgs...)>
struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
private:
using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));

template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
struct Impl {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
std::get<IDX>(args))) {
return false;
}

constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
}
};

template <typename TupleT, size_t IDX>
struct Impl<TupleT, IDX, true> {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
return true;
}
};

public:
template <typename... Args>
static bool Compare(const CUDAKernelParams &params, Args... args) {
constexpr auto kNumArgs = sizeof...(FuncArgs);
static_assert(kNumArgs == sizeof...(Args), "Argument number not match");

auto args_tuple = std::make_tuple(args...);
using TupleT = typename std::decay<decltype(args_tuple)>::type;
return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
}
};

} // namespace gpu
} // namespace backends
} // namespace phi
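The moved helper encodes a bitwise comparison between a recorded kernel launch and a candidate argument list. A toy illustration of the contract (the kernel and wrapper below are hypothetical; `CUDAKernelParams` comes from this header):

```cpp
// Toy kernel, used only to instantiate the helper.
__global__ void ScaleKernel(const float *x, float *y, float a, int n);

using ScaleHelper =
    phi::backends::gpu::IsSameKernelHelper<decltype(&ScaleKernel),
                                           &ScaleKernel>;

bool IsSameScaleLaunch(const phi::backends::gpu::CUDAKernelParams &params,
                       const float *x, float *y, float a, int n) {
  // True only if each argument is bitwise-equal (std::memcmp) to the
  // corresponding argument recorded in `params`; the recursion over the
  // argument tuple stops at the first mismatch.
  return ScaleHelper::Compare(params, x, y, a, n);
}
```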
120 changes: 120 additions & 0 deletions paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h
@@ -0,0 +1,120 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#endif

namespace phi {
namespace backends {
namespace gpu {

#ifdef PADDLE_WITH_CUDA
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \
__kernel_func, \
__grid, \
__block, \
__sm_size, \
__stream, \
__seed_inc, \
__seed_expr, \
__offset_expr, \
...) \
do { \
if (CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
using __Helper = \
phi::backends::gpu::IsSameKernelHelper<decltype(&__kernel_func), \
&__kernel_func>; \
auto *dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get( \
CUDAGraph::CapturingPlace()); \
auto __set_seed_func = \
[=](phi::backends::gpu::CUDAKernelParams *__params, \
bool __check_only) -> bool { \
if (__check_only) { \
return __params->func() == &__kernel_func && \
__Helper::Compare(*__params, __VA_ARGS__); \
} \
auto &KERNEL_PARAMS = *__params; \
uint64_t __seed, __offset; \
::paddle::operators::GetSeedDataAndIncrement( \
*dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset); \
__seed_expr = static_cast<decltype(__seed_expr)>(__seed); \
__offset_expr = static_cast<decltype(__offset_expr)>(__offset); \
return true; \
}; \
CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
} \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#else
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \
__kernel_func, \
__grid, \
__block, \
__sm_size, \
__stream, \
__seed_inc, \
__seed_expr, \
__offset_expr, \
...) \
do { \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#endif

inline bool IsCUDAGraphCapturing() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::IsCapturing();
#else
return false;
#endif
}

// Add reset callback if CUDA Graph is capturing.
// Otherwise, invoke callback directly.
template <typename Callback>
inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
return CUDAGraph::AddResetCallbackDuringCapturing(
std::forward<Callback>(callback));
}
#endif
callback();
}

template <typename T>
inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
static_assert(std::is_trivial<T>::value, "T must be trivial type");
static_assert(!std::is_same<T, void>::value, "T cannot be void");
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
size_t nbytes = size * sizeof(T);
void *new_host_mem = new uint8_t[nbytes];
std::memcpy(new_host_mem, host_mem, nbytes);
AddResetCallbackIfCapturingCUDAGraph(
[new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
return reinterpret_cast<T *>(new_host_mem);
}
#endif
return host_mem;
}

} // namespace gpu
} // namespace backends
} // namespace phi
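One more sketch for the host-memory helper above: during capture it heap-copies a trivially-copyable host buffer and frees the copy on graph reset, so replays never read a stack buffer that has gone out of scope (the offsets table below is illustrative):

```cpp
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"

void LaunchWithHostTable() {
  int offsets[4] = {0, 3, 5, 9};  // stack storage, dies with this frame
  // While a graph is being captured, the returned pointer is a heap copy
  // whose lifetime is tied to the graph; otherwise `offsets` is returned
  // unchanged.
  int *safe_offsets =
      phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(offsets, 4);
  // ... pass `safe_offsets` (not `offsets`) into the kernel launch ...
}
```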