add DeviceContext design doc #2648
## Context Design

A Net is executed by one or several threads. A Context is related to a thread and holds the necessary runtime resources.

Context is defined as follows:

```c++
class Context {};
```

### CUDAContext

The Tensor computation is executed by the Eigen library, which needs an Eigen::GpuDevice object as a parameter. The GpuDevice object is constructed from an Eigen::CudaStreamDevice object, and we need both a specific GPU id and a CudaStream to create an Eigen::CudaStreamDevice object.

At the same time, some computation will be executed by the cublas or cudnn library. Take cublas as an example: we have to acquire a cublasHandle that is bound to a CudaStream in order to do computation, in the same way as the Eigen library does.

The future DAGNet is run by multiple threads, and each thread will have its own Eigen::GpuDevice object bound to a different CudaStream. Multiple threads can run in parallel on the same GPU card.

Copy (communication) work is the responsibility of a dedicated thread. The copy thread only needs the CudaStream from the corresponding Context.

We can make a summary:

- Different GPU cards have different GPU ids, and we can do data parallelism on multiple GPUs.
- Multiple threads can run a Net in parallel on a single GPU card, and each thread has its own Context.
- A single thread can also execute a Net sequentially; in that case all computation and communication work uses the same Context.

CUDAContext is defined as follows:

```c++
// RAII guard that makes sure CUDA API calls run on the right GPU device,
// and restores the previous device afterwards.
class DeviceGuard {
 public:
  explicit DeviceGuard(int new_device) : previous_(GetCurrentGPUID()) {
    if (previous_ != new_device) {
      cudaError_t err = cudaSetDevice(new_device);
      PADDLE_ASSERT(err == cudaSuccess);
    }
  }

  ~DeviceGuard() {
    cudaError_t err = cudaSetDevice(previous_);
    PADDLE_ASSERT(err == cudaSuccess);
  }

 private:
  int previous_;
};

class CUDAContext : public Context {
 public:
  explicit CUDAContext(const int gpu_id) : gpu_id_(gpu_id) {
    DeviceGuard guard(gpu_id_);
    cudaError_t err = cudaStreamCreate(&stream_);
    PADDLE_ASSERT(err == cudaSuccess);

    eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
    eigen_handle_ = new Eigen::GpuDevice(eigen_stream_);
  }

  // Synchronizes the CUDA stream; called by the destructor and whenever
  // a caller needs all queued work on this stream to finish.
  void Wait() {
    cudaError_t err = cudaStreamSynchronize(stream_);
    PADDLE_ASSERT(err == cudaSuccess);
  }

  cudaStream_t GetStream() { return stream_; }

  Eigen::GpuDevice GetEigenHandle() { return *eigen_handle_; }

  cublasHandle_t GetBlasHandle() {
    if (!blas_handle_) {
      DeviceGuard guard(gpu_id_);
      cublasStatus_t err = cublasCreate(&blas_handle_);
      PADDLE_ASSERT(err == CUBLAS_STATUS_SUCCESS);
      err = cublasSetStream(blas_handle_, stream_);
      PADDLE_ASSERT(err == CUBLAS_STATUS_SUCCESS);
    }
    return blas_handle_;
  }

  cudnnHandle_t GetCUDNNHandle() {
    if (!dnn_handle_) {
      DeviceGuard guard(gpu_id_);
      cudnnStatus_t err = cudnnCreate(&dnn_handle_);
      PADDLE_ASSERT(err == CUDNN_STATUS_SUCCESS);
      err = cudnnSetStream(dnn_handle_, stream_);
      PADDLE_ASSERT(err == CUDNN_STATUS_SUCCESS);
    }
    return dnn_handle_;
  }

  curandGenerator_t GetRandHandle() {
    if (!rand_handle_) {
      DeviceGuard guard(gpu_id_);
      curandStatus_t err =
          curandCreateGenerator(&rand_handle_, CURAND_RNG_PSEUDO_DEFAULT);
      PADDLE_ASSERT(err == CURAND_STATUS_SUCCESS);
      err = curandSetPseudoRandomGeneratorSeed(rand_handle_, random_seed_);
      PADDLE_ASSERT(err == CURAND_STATUS_SUCCESS);
      err = curandSetStream(rand_handle_, stream_);
      PADDLE_ASSERT(err == CURAND_STATUS_SUCCESS);
    }
    return rand_handle_;
  }

  ~CUDAContext() {
    Wait();

    if (blas_handle_) {
      cublasStatus_t err = cublasDestroy(blas_handle_);
      PADDLE_ASSERT(err == CUBLAS_STATUS_SUCCESS);
    }

    if (dnn_handle_) {
      cudnnStatus_t err = cudnnDestroy(dnn_handle_);
      PADDLE_ASSERT(err == CUDNN_STATUS_SUCCESS);
    }

    if (rand_handle_) {
      curandStatus_t err = curandDestroyGenerator(rand_handle_);
      PADDLE_ASSERT(err == CURAND_STATUS_SUCCESS);
    }

    delete eigen_handle_;
    delete eigen_stream_;

    // The stream is created first, so it is destroyed last.
    cudaError_t err = cudaStreamDestroy(stream_);
    PADDLE_ASSERT(err == cudaSuccess);
  }

 private:
  int gpu_id_;
  cudaStream_t stream_;

  Eigen::CudaStreamDevice* eigen_stream_;
  Eigen::GpuDevice* eigen_handle_;

  cublasHandle_t blas_handle_{nullptr};
  cudnnHandle_t dnn_handle_{nullptr};

  int random_seed_;
  curandGenerator_t rand_handle_{nullptr};
};
```

### CPUContext

The CPUContext is defined as follows:

```c++
class CPUContext : public Context {
 public:
  Eigen::DefaultDevice GetEigenHandle() {
    if (!eigen_handle_) {
      eigen_handle_ = new Eigen::DefaultDevice();
    }
    return *eigen_handle_;
  }

 private:
  Eigen::DefaultDevice* eigen_handle_{nullptr};
};
```

An Operator does not create device resources itself: when it runs, it receives the Context of the executing thread and obtains its streams and library handles from that Context.