
Enhance device context pool #9293


Merged
merged 1 commit into from Mar 22, 2018
35 changes: 19 additions & 16 deletions paddle/fluid/platform/device_context.cc
@@ -10,43 +10,45 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_context.h"
+ #include <unordered_set>
#include "paddle/fluid/memory/memory.h"

namespace paddle {
namespace platform {

DeviceContextPool* DeviceContextPool::pool = nullptr;

- const platform::DeviceContext* DeviceContextPool::Get(
-     const platform::Place& place) {
+ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
Review comment (Contributor): platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) const

  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
-   return it->second;
+   return it->second.get();
}

DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
-   for (size_t i = 0; i < places.size(); i++) {
-     if (platform::is_cpu_place(places[i])) {
+   using PtrType = std::unique_ptr<DeviceContext>;
+   std::unordered_set<Place, PlaceHash> set;
+   for (auto& p : places) {
+     set.insert(p);
+   }
+
+   for (auto& p : set) {
+     if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
-       device_contexts_.emplace(places[i],
-                                new platform::MKLDNNDeviceContext(
-                                    boost::get<platform::CPUPlace>(places[i])));
+       device_contexts_.emplace(
+           p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else
-       device_contexts_.emplace(places[i],
-                                new platform::CPUDeviceContext(
-                                    boost::get<platform::CPUPlace>(places[i])));
+       device_contexts_.emplace(
+           p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif
-     } else if (platform::is_gpu_place(places[i])) {
+     } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
-       device_contexts_.emplace(places[i],
-                                new platform::CUDADeviceContext(
-                                    boost::get<platform::CUDAPlace>(places[i])));
+       device_contexts_.emplace(
+           p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
@@ -159,6 +161,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; }

void CUDADeviceContext::Wait() const {
+   std::lock_guard<std::mutex> guard(mutex_);
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
  PADDLE_ENFORCE(cudaGetLastError());
}
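The lock added to Wait() pairs with the mutable std::mutex member added in device_context.h below: marking the mutex mutable is the standard C++ way to let a const member function lock it. A tiny illustrative sketch of the idiom (not Paddle code, names are made up):

#include <mutex>

class Synchronizer {
 public:
  // Wait() is logically read-only, hence const; because mutex_ is declared
  // mutable, this const method may still lock it to serialize concurrent callers.
  void Wait() const {
    std::lock_guard<std::mutex> guard(mutex_);
    // ... e.g. synchronize a device stream here ...
  }

 private:
  mutable std::mutex mutex_;
};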
18 changes: 4 additions & 14 deletions paddle/fluid/platform/device_context.h
@@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext {
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;

+   mutable std::mutex mutex_;
  cudaStream_t stream_;
  cudnnHandle_t cudnn_handle_;
  cublasHandle_t cublas_handle_;
@@ -159,7 +160,7 @@ class DeviceContextPool {
}

  /*! \brief Return handle of single device context. */
-   const platform::DeviceContext* Get(const platform::Place& place);
+   platform::DeviceContext* Get(const platform::Place& place);
Review comment (Contributor): should add const suffix?
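A sketch of what that suggestion might look like on the definition in device_context.cc (an illustration of the reviewer's idea, not code from this PR):

// Hypothetical const-qualified accessor; mirrors the merged definition
// with only the trailing const added.
platform::DeviceContext* DeviceContextPool::Get(
    const platform::Place& place) const {
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  return it->second.get();
}

Since find() on the map and unique_ptr::get() are both usable from a const member function, this would compile; the declaration here would need the matching const.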


  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
Review comment (Contributor): remove the const prefix

@@ -172,19 +173,8 @@ class DeviceContextPool {

 private:
  static DeviceContextPool* pool;
-   constexpr static int LEFT_SHIFT = 8;
-   struct Hash {
-     std::hash<int> hash_;
-     size_t operator()(const platform::Place& place) const {
-       int pre_hash = place.which() << LEFT_SHIFT;
-       if (platform::is_gpu_place(place)) {
-         pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
-       }
-       return hash_(pre_hash);
-     }
-   };
-   std::unordered_map<const platform::Place, const platform::DeviceContext*,
-                      Hash>
+   std::unordered_map<const platform::Place,
+                      std::unique_ptr<platform::DeviceContext>, PlaceHash>
      device_contexts_;
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
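Taken together, the pool changes amount to this: deduplicate the requested places with an unordered_set keyed by the new PlaceHash, own each DeviceContext through a std::unique_ptr inside the map, and hand out non-owning raw pointers from Get() via unique_ptr::get(). A minimal self-contained sketch of that ownership pattern, with toy stand-ins for Place and DeviceContext (illustrative only, not Paddle code):

#include <functional>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy stand-ins; names and members are hypothetical.
struct ToyPlace { int device; };
struct ToyPlaceHash {
  std::size_t operator()(const ToyPlace& p) const { return std::hash<int>()(p.device); }
};
struct ToyPlaceEq {
  bool operator()(const ToyPlace& a, const ToyPlace& b) const { return a.device == b.device; }
};
struct ToyContext {
  explicit ToyContext(int d) : device(d) {}
  int device;
};

class ToyPool {
 public:
  explicit ToyPool(const std::vector<ToyPlace>& places) {
    // Deduplicate first so each distinct place gets exactly one context,
    // even if callers pass the same place several times.
    std::unordered_set<ToyPlace, ToyPlaceHash, ToyPlaceEq> uniq(places.begin(),
                                                                places.end());
    for (const auto& p : uniq) {
      contexts_.emplace(p, std::unique_ptr<ToyContext>(new ToyContext(p.device)));
    }
  }

  // The pool keeps ownership; callers receive a non-owning raw pointer.
  ToyContext* Get(const ToyPlace& place) {
    auto it = contexts_.find(place);
    if (it == contexts_.end()) throw std::runtime_error("place not supported");
    return it->second.get();
  }

 private:
  std::unordered_map<ToyPlace, std::unique_ptr<ToyContext>, ToyPlaceHash, ToyPlaceEq>
      contexts_;
};

Constructing ToyPool with {{0}, {0}, {1}} would create just two contexts, one for device 0 and one for device 1, which is the behavior the unordered_set in the real constructor provides.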
12 changes: 12 additions & 0 deletions paddle/fluid/platform/place.h
@@ -65,6 +65,18 @@ bool is_cpu_place(const Place &);
bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &);

+ struct PlaceHash {
+   std::size_t operator()(const Place &p) const {
+     constexpr size_t num_dev_bits = 4;
Review comment (Contributor): 4 bits is not enough; the GPU box product has 32 cards in one node, which will lead to an overlap in dev_id << num_dev_bits | p.which()

+     std::hash<int> ihash;
+     size_t dev_id = 0;
+     if (is_gpu_place(p)) {
+       dev_id = boost::get<CUDAPlace>(p).device;
+     }
+     return ihash(dev_id << num_dev_bits | p.which());
+   }
+ };
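To make the bit layout concrete: the low num_dev_bits bits carry p.which() (the variant index of the place type) and the higher bits carry the device id. A small standalone sketch of the same arithmetic; the 0/1 variant indices below are assumptions for illustration, not taken from place.h:

#include <cstddef>
#include <cstdio>
#include <functional>

// Same combining scheme as PlaceHash: (dev_id << num_dev_bits) | which.
std::size_t toy_place_hash(std::size_t dev_id, int which) {
  constexpr std::size_t num_dev_bits = 4;
  std::hash<int> ihash;
  return ihash(static_cast<int>((dev_id << num_dev_bits) | which));
}

int main() {
  // Assuming which() == 0 for CPUPlace and 1 for CUDAPlace (hypothetical):
  // CPUPlace         -> key (0 << 4) | 0 = 0
  // CUDAPlace, dev 3 -> key (3 << 4) | 1 = 49
  std::printf("%zu\n", toy_place_hash(0, 0));
  std::printf("%zu\n", toy_place_hash(3, 1));
  return 0;
}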

std::ostream &operator<<(std::ostream &, const Place &);

template <typename Visitor>