-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Refine device context and fix GetPlace() #3084
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() | |||
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); | |||
} | |||
|
|||
CPUDeviceContext::CPUDeviceContext() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move implementation into *.cc
to make the definition clearer.
} | ||
|
||
CUDADeviceContext::~CUDADeviceContext() { | ||
SetDeviceId(place_.device); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add SetDeviceId(place_.device);
in here
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); | ||
} | ||
|
||
Place CUDADeviceContext::GetPlace() const { return place_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix bug in here, return place_
eigen_device_.reset(new Eigen::DefaultDevice()); | ||
} | ||
|
||
CPUDeviceContext::CPUDeviceContext(CPUPlace place) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep unified cpu and gpu interface
|
||
cublasHandle_t blas_handle_{nullptr}; | ||
private: | ||
uint64_t seed_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to set seed to uint64_t
and I will also add rand method for CPU in next PR
@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place, | |||
const void* src, size_t num, | |||
cudaStream_t stream) { | |||
if (dst_place == src_place) { | |||
platform::GPUPlaceGuard g(src_place.device); | |||
platform::SetDeviceId(src_place.device); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个SetDeviceId是不是要写在if外面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不需要了 因为GpuMemcpyPeer函数里面的两个参数 就是相应的device id
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
No description provided.