Skip to content

Commit

Permalink
Added locking for thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
artyom-beilis committed Feb 3, 2022
1 parent 032a693 commit 437e743
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/CLTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ namespace ptdlprim {
if(allocated_size == 0 && debug_allocator)
setlocale(LC_ALL,"");

std::unique_lock<std::mutex> g(lock);

std::int64_t size = round(orig_size);
std::unique_ptr<CLMemAllocation> res;

Expand Down Expand Up @@ -53,6 +55,8 @@ namespace ptdlprim {
}
void CLCache::release(std::unique_ptr<CLMemAllocation> &&mem)
{
std::unique_lock<std::mutex> g(lock);

int64_t size = mem->size;
cached_size += mem->size;
requested_size -= mem->orig_size;
Expand All @@ -63,6 +67,7 @@ namespace ptdlprim {

void CLCache::clear()
{
std::unique_lock<std::mutex> g(lock);
{
allocation_type tmp;
tmp.swap(allocation);
Expand Down
20 changes: 12 additions & 8 deletions src/CLTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ namespace ptdlprim {

class CLCache {
public:
CLCache() {}

CLCache(CLCache const &) = delete;
void operator=(CLCache const &) = delete;
typedef std::map<std::int64_t,std::list<std::unique_ptr<CLMemAllocation> > > allocation_type;
std::mutex lock;
allocation_type allocation;

std::int64_t allocated_size = 0;
Expand Down Expand Up @@ -66,10 +71,8 @@ namespace ptdlprim {
~CLContextManager()
{
{
std::vector<DevData> tmp;
tmp.swap(data_);
for(DevData &data:tmp)
data.cache.clear();
for(auto &data:data_)
data->cache.clear();
}
no_cache_ = true;
}
Expand Down Expand Up @@ -162,8 +165,9 @@ namespace ptdlprim {
continue;
}
for(size_t j=0;j<devices.size();j++) {
data_.push_back(DevData());
data_.back().name = std::to_string(i) + ":" + std::to_string(j);
std::unique_ptr<DevData> d(new DevData());
data_.push_back(std::move(d));
data_.back()->name = std::to_string(i) + ":" + std::to_string(j);
}
}
}
Expand All @@ -174,7 +178,7 @@ namespace ptdlprim {
i = 0;
if(i >= int(data_.size()))
throw std::runtime_error("Invalid Device #" + std::to_string(i));
DevData &res = data_[i];
DevData &res = *data_[i];
if(res.ready)
return res;
res.ctx=dlprim::Context(res.name);
Expand All @@ -186,7 +190,7 @@ namespace ptdlprim {
}


std::vector<DevData> data_;
std::vector<std::unique_ptr<DevData> > data_;
bool no_cache_;
};

Expand Down

0 comments on commit 437e743

Please sign in to comment.