-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Enhance device context pool #9293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext { | |
std::unique_ptr<Eigen::GpuDevice> eigen_device_; | ||
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; | ||
|
||
mutable std::mutex mutex_; | ||
cudaStream_t stream_; | ||
cudnnHandle_t cudnn_handle_; | ||
cublasHandle_t cublas_handle_; | ||
|
@@ -159,7 +160,7 @@ class DeviceContextPool { | |
} | ||
|
||
/*! \brief Return handle of single device context. */ | ||
const platform::DeviceContext* Get(const platform::Place& place); | ||
platform::DeviceContext* Get(const platform::Place& place); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should add const suffix? |
||
|
||
template <typename Place> | ||
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the const prefix |
||
|
@@ -172,19 +173,8 @@ class DeviceContextPool { | |
|
||
private: | ||
static DeviceContextPool* pool; | ||
constexpr static int LEFT_SHIFT = 8; | ||
struct Hash { | ||
std::hash<int> hash_; | ||
size_t operator()(const platform::Place& place) const { | ||
int pre_hash = place.which() << LEFT_SHIFT; | ||
if (platform::is_gpu_place(place)) { | ||
pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId(); | ||
} | ||
return hash_(pre_hash); | ||
} | ||
}; | ||
std::unordered_map<const platform::Place, const platform::DeviceContext*, | ||
Hash> | ||
std::unordered_map<const platform::Place, | ||
std::unique_ptr<platform::DeviceContext>, PlaceHash> | ||
device_contexts_; | ||
DISABLE_COPY_AND_ASSIGN(DeviceContextPool); | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,18 @@ bool is_cpu_place(const Place &); | |
bool places_are_same_class(const Place &, const Place &); | ||
bool is_same_place(const Place &, const Place &); | ||
|
||
struct PlaceHash { | ||
std::size_t operator()(const Place &p) const { | ||
constexpr size_t num_dev_bits = 4; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 4 bit is not enough, the GPU box product has 32 cards in one node. Then will lead to an overlap of |
||
std::hash<int> ihash; | ||
size_t dev_id = 0; | ||
if (is_gpu_place(p)) { | ||
dev_id = boost::get<CUDAPlace>(p).device; | ||
} | ||
return ihash(dev_id << num_dev_bits | p.which()); | ||
} | ||
}; | ||
|
||
std::ostream &operator<<(std::ostream &, const Place &); | ||
|
||
template <typename Visitor> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.