Skip to content

Commit

Permalink
fix cuda graph (#51648)
Browse files Browse the repository at this point in the history
  • Loading branch information
pangyoki authored Mar 15, 2023
1 parent 4283e19 commit 53c73c7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 56 deletions.
12 changes: 10 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ void InterpreterCore::PrepareForCUDAGraphCapture() {
platform::IsCUDAGraphCapturing(),
false,
platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
"when running the first batch."));
"before prepare."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
true,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -684,8 +684,16 @@ void InterpreterCore::Convert(
if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"op_type can't be memcpy d2h or h2d while using cuda graph."));
"Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
}
PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
true,
platform::errors::InvalidArgument(
"Device context of op %s must be [%s] while using "
"cuda graph, but got [%s].",
op_type,
typeid(phi::GPUContext).name(),
typeid(*dev_ctx_).name()));
// cuda graph needs to record all stream
phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_);
Expand Down
105 changes: 61 additions & 44 deletions paddle/fluid/platform/cuda_graph_with_memory_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,45 +40,68 @@ void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) {
dev_ctx->cusolver_dn_handle();
}

phi::DeviceContext* SelectCUDAGraphDeviceContext(phi::GPUPlace place,
int64_t* pool_id) {
phi::DeviceContext* mutable_dev_ctx;
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
auto num_stream = all_capturing_dev_ctxs.size();
if (num_stream > 0) {
// Capturing device contexts will only be recorded in new
// executor in temporary, that is,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
// This restriction can be removed if device context is
// recorded in other modes.
// Record method: RecordCapturingDeviceContext.
PADDLE_ENFORCE_EQ(FLAGS_new_executor_use_cuda_graph,
true,
platform::errors::InvalidArgument(
"FLAGS_new_executor_use_cuda_graph must be True when "
"capturing stream is recorded."));
if (num_stream > 1) {
VLOG(4) << "Use a new stream to capture cuda graph. Used in multi-stream "
"scenarios with new executor.";
if (*pool_id <= CUDAGraph::kInvalidPoolID) {
*pool_id = CUDAGraph::UniqueMemoryPoolID();
}
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
*pool_id, place, 0);
} else if (num_stream == 1) {
VLOG(4) << "Use recorded stream to capture cuda graph. Used in "
"single-stream scenarios with new executor.";
mutable_dev_ctx = *(all_capturing_dev_ctxs.begin());
}
} else {
VLOG(4) << "Use default stream to capture cuda graph.";
mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
}
return mutable_dev_ctx;
}

void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id) {
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto* mutable_dev_ctx = SelectCUDAGraphDeviceContext(place, &pool_id);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
InitCUDNNRelatedHandle(dev_ctx);

auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
// create_cuda_graph_stream: Whether to create a new stream to
// capture cuda graph, usually used in multi-stream scenarios.
// Can only be used for new executor in static mode, that is,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
bool create_cuda_graph_stream = false;
if (FLAGS_new_executor_use_cuda_graph &&
(all_capturing_dev_ctxs.size() > 1 ||
(all_capturing_dev_ctxs.size() == 1 &&
(*(all_capturing_dev_ctxs.begin()) != mutable_dev_ctx)))) {
create_cuda_graph_stream = true;
}
if (create_cuda_graph_stream) {
VLOG(4) << "create a new stream to capture cuda graph.";
if (pool_id <= CUDAGraph::kInvalidPoolID) {
pool_id = CUDAGraph::UniqueMemoryPoolID();
}
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
auto num_stream = all_capturing_dev_ctxs.size();
if (num_stream > 1) {
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
InitCUDNNRelatedHandle(capturing_dev_ctx);
}
}
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
InitCUDNNRelatedHandle(dev_ctx);

auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
CUDAGraph::SetIsCUDAGraphStreamCreated(create_cuda_graph_stream);

// When using cuda graph in new executor, fast GC must be used.
// FLAGS_use_stream_safe_cuda_allocator should be true.
Expand All @@ -96,7 +119,7 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = true;
}
if (create_cuda_graph_stream) {
if (num_stream > 1) {
// Set cuda graph allocator for all streams.
// Establish dependencies between cuda graph stream and all other streams
// using eventWait, so that all streams will be captured.
Expand Down Expand Up @@ -129,20 +152,17 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
}

std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
phi::DeviceContext* mutable_dev_ctx;
auto place = CUDAGraph::CapturingPlace();
bool create_cuda_graph_stream = CUDAGraph::IsCUDAGraphStreamCreated();
if (create_cuda_graph_stream) {
auto pool_id = CUDAGraph::CapturingPoolID();
auto* mutable_dev_ctx = SelectCUDAGraphDeviceContext(place, &pool_id);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);

auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
auto num_stream = all_capturing_dev_ctxs.size();
if (num_stream > 1) {
// join all other streams back to origin cuda graph stream.
int64_t pool_id = CUDAGraph::CapturingPoolID();
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
auto* cuda_graph_dev_ctx =
reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
Expand All @@ -152,19 +172,16 @@ std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
capturing_dev_ctx->GetPlace(),
platform::GenerateDeviceEventFlag());
capturing_event->Record(capturing_dev_ctx);
capturing_event->Wait(platform::kCUDA, cuda_graph_dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: "
<< cuda_graph_dev_ctx
capturing_event->Wait(platform::kCUDA, dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: " << dev_ctx
<< " wait for capturing dev_ctx: " << capturing_dev_ctx;
capturing_dev_ctx->cudnn_workspace_handle().ResetWorkspace();
capturing_dev_ctx->SetCUDAGraphAllocator(nullptr);
}
phi::backends::gpu::CUDAGraphContextManager::Instance()
.ClearDeviceContextsRecords();
} else {
mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
}
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);

phi::backends::gpu::CUDAGraphContextManager::Instance()
.ClearDeviceContextsRecords();
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
dev_ctx->SetCUDAGraphAllocator(nullptr);
return CUDAGraph::EndCapture();
Expand Down
10 changes: 0 additions & 10 deletions paddle/phi/backends/gpu/cuda/cuda_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,6 @@ class CUDAGraph {
// supported during capturing CUDA Graph.
static bool IsValidCapturing();

static void SetIsCUDAGraphStreamCreated(bool create_cuda_graph_stream) {
capturing_graph_->is_cuda_graph_stream_created_ = create_cuda_graph_stream;
}

static bool IsCUDAGraphStreamCreated() {
return capturing_graph_->is_cuda_graph_stream_created_;
}

static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
Expand Down Expand Up @@ -254,8 +246,6 @@ class CUDAGraph {

bool is_first_run_{true};

bool is_cuda_graph_stream_created_{false};

static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
Expand Down

0 comments on commit 53c73c7

Please sign in to comment.