Commit
Merge pull request pytorch#709 from colesbury/pinned_memory
Fix bug where pinned memory event could be recorded on incorrect device
soumith authored Feb 28, 2017
2 parents 4ef3036 + 76de151 commit 8dfcf7e
Showing 7 changed files with 132 additions and 94 deletions.
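The bug being fixed: the caching host allocator created and recorded its cudaEvents on whatever device happened to be current, even when a pinned block had been used on a stream belonging to a different GPU, and a CUDA event is tied to the device that is current when it is created. The patch gives THCStream an explicit device and makes the allocator switch to each stream's device before recording (see insertEvents() in the THCCachingHostAllocator.cpp hunk below). The following is a minimal sketch of that save/set/restore pattern, using hypothetical names (StreamRef, recordEventsOnStreams) rather than the actual THC code:

    #include <cuda_runtime_api.h>
    #include <vector>

    struct StreamRef {               // stand-in for THCStream: a stream plus the device that owns it
        cudaStream_t stream;
        int device;
    };

    // Record a non-timing event on each stream, creating every event while that
    // stream's own device is current, then restore the caller's device.
    static cudaError_t recordEventsOnStreams(const std::vector<StreamRef>& streams,
                                             std::vector<cudaEvent_t>& events)
    {
        int prev_device;
        cudaError_t err = cudaGetDevice(&prev_device);
        if (err != cudaSuccess) return err;

        for (const auto& s : streams) {
            err = cudaSetDevice(s.device);        // events belong to the current device
            if (err != cudaSuccess) break;

            cudaEvent_t event;
            err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
            if (err != cudaSuccess) break;

            err = cudaEventRecord(event, s.stream);
            if (err != cudaSuccess) break;

            events.push_back(event);
        }

        cudaSetDevice(prev_device);               // always restore the caller's device
        return err;
    }

As in the patch's insertEvents(), the previous device is restored unconditionally, even when an earlier CUDA call failed.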
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -51,6 +51,7 @@ IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
THCTensorRandom.cpp
THCCachingAllocator.cpp
THCCachingHostAllocator.cpp
THCStream.cpp
PROPERTIES COMPILE_FLAGS -std=${CXX_VERSION})
ELSE()
SET(CMAKE_CXX_STANDARD 11)
@@ -158,7 +159,7 @@ SET(src
THCCachingHostAllocator.cpp
THCGeneral.c
THCStorageCopy.c
THCStream.c
THCStream.cpp
THCTensor.c
THCTensorCopy.c
THCTensorRandom.cpp
82 changes: 44 additions & 38 deletions THCCachingHostAllocator.cpp
@@ -2,16 +2,18 @@

#include <cuda_runtime_api.h>
#include <deque>
#include <memory>
#include <mutex>
#include <set>
#include <stdint.h>
#include <unordered_map>
#include <unordered_set>
#include <utility>


namespace {

typedef std::shared_ptr<THCStream> THCStreamPtr;

struct BlockSize
{
size_t size; // allocation size
@@ -24,25 +26,12 @@ struct Block : public BlockSize
{
bool allocated; // true if the block is currently allocated
int event_count; // number of outstanding cuda events
std::unordered_set<THCStream *> streams;
std::set<THCStreamPtr> streams;

Block(size_t size, void* ptr, bool allocated) :
BlockSize(size, ptr), allocated(allocated), event_count(0) { }
BlockSize(size, ptr), allocated(allocated), event_count(0), streams() {}
};

struct BlockStreamCleaner {
std::unordered_set<THCStream *> &streams;

BlockStreamCleaner(std::unordered_set<THCStream *> &streams) : streams(streams) {}
~BlockStreamCleaner() {
for(auto it = streams.begin(); it != streams.end(); ++it) {
if (*it != NULL) {
THCStream_free(*it);
}
}
streams.clear();
}
};
static bool BlockComparator(const BlockSize& a, const BlockSize& b)
{
// sort by size, break ties with pointer
@@ -129,25 +118,12 @@ struct HostAllocator
// we process the streams.
block.allocated = false;

// since the block has been deallocated, no point in keeping around the
// streams, even in case of error.
BlockStreamCleaner sc(block.streams);
for (auto it = block.streams.begin(); it != block.streams.end(); ++it) {
cudaEvent_t event;
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
if (err != cudaSuccess) {
return err;
}

err = cudaEventRecord(event, (*it) == NULL ? NULL : (*it)->stream);
if (err != cudaSuccess) {
return err;
}

// the block will not be re-used until all associated events have occured
block.event_count++;
cuda_events.emplace_back(event, ptr);
// insert CUDA events for each stream on which this block was used. This
// prevents the block from being re-used until the events have occurred.
err = insertEvents(block);
if (err != cudaSuccess) {
return err;
}

if (block.event_count == 0) {
// the block can be re-used if there are no outstanding cuda events
available.insert(block);
@@ -168,11 +144,11 @@

Block& block = it->second;
THAssert(block.allocated);
auto res = block.streams.insert(stream);
if (res.second == true && stream != NULL) {
THCStream_retain(stream);
}

THCStreamPtr stream_ptr(stream, &THCStream_free);
THCStream_retain(stream);

block.streams.insert(std::move(stream_ptr));
return cudaSuccess;
}

@@ -239,6 +215,36 @@ struct HostAllocator
}
}
}

cudaError_t insertEvents(Block& block)
{
cudaError_t err;

int prev_device;
err = cudaGetDevice(&prev_device);
if (err != cudaSuccess) return err;

std::set<THCStreamPtr> streams(std::move(block.streams));
for (auto it = streams.begin(); it != streams.end(); ++it) {
auto& stream = *it;

err = cudaSetDevice(stream->device);
if (err != cudaSuccess) break;

cudaEvent_t event;
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
if (err != cudaSuccess) break;

err = cudaEventRecord(event, stream->stream);
if (err != cudaSuccess) break;

block.event_count++;
cuda_events.emplace_back(event, block.ptr);
}

cudaSetDevice(prev_device);
return err;
}
};

} // namespace
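One more thing the hunk above changes is ownership: the deleted BlockStreamCleaner walked a set of raw THCStream* and called THCStream_free by hand, whereas the new code keeps std::shared_ptr<THCStream> objects whose deleter is THCStream_free, so clearing a block's stream set releases the references automatically. Below is a small self-contained sketch of that shared_ptr-with-custom-deleter pattern; Handle, Handle_retain, and Handle_free are stand-ins for the THC types, not part of the patch:

    #include <memory>
    #include <set>

    struct Handle { int refcount; };
    static void Handle_retain(Handle* h) { ++h->refcount; }                     // cf. THCStream_retain
    static void Handle_free(Handle* h)   { if (--h->refcount == 0) delete h; }  // cf. THCStream_free

    typedef std::shared_ptr<Handle> HandlePtr;

    int main()
    {
        std::set<HandlePtr> streams;               // plays the role of Block::streams
        Handle* raw = new Handle{1};               // caller holds one reference

        HandlePtr ptr(raw, &Handle_free);          // shared_ptr calls Handle_free, never delete
        Handle_retain(raw);                        // account for the reference the set now owns
        streams.insert(std::move(ptr));

        streams.clear();                           // drops the set's reference via Handle_free
        Handle_free(raw);                          // drop the caller's reference; object destroyed here
        return 0;
    }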
45 changes: 22 additions & 23 deletions THCGeneral.c
@@ -107,9 +107,9 @@ void THCudaInit(THCState* state)
THCudaCheck(cudaSetDevice(i));
THCudaCheck(cudaGetDeviceProperties(&state->deviceProperties[i], i));

// Allocate space for the NULL stream
// Allocate space for the default stream
res->streams = (THCStream**) malloc(sizeof(THCStream*));
res->streams[0] = NULL;
res->streams[0] = THCStream_defaultStream(i);

/* The scratch space that we want to have available per each device is
based on the number of SMs available per device. We guarantee a
@@ -158,8 +158,8 @@ void THCudaShutdown(THCState* state)
for (int dev = 0; dev < deviceCount; ++dev) {
THCudaCheck(cudaSetDevice(dev));
THCCudaResourcesPerDevice* res = &(state->resourcesPerDevice[dev]);
/* Free user reserved streams (0 is the default stream) */
for (int i = 1; i <= state->numUserStreams; ++i) {
/* Free all streams */
for (int i = 0; i <= state->numUserStreams; ++i) {
THCStream_free(res->streams[i]);
}
/* Free user defined BLAS handles */
@@ -428,7 +428,7 @@ cudaStream_t THCState_getDeviceStream(THCState *state, int device, int streamIndex)
}
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
THCStream* stream = res->streams[streamIndex];
return stream ? stream->stream : NULL;
return stream->stream;
}

cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle)
@@ -444,18 +444,24 @@ cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle)

static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
{
return (THCStream*) THCThreadLocal_get(state->currentStreams[device]);
THCThreadLocal local = state->currentStreams[device];
THCStream* stream = (THCStream*)THCThreadLocal_get(local);
if (!stream) {
stream = THCStream_defaultStream(device);
THCStream_retain(stream);
THCThreadLocal_set(local, stream);
}
return stream;
}

static void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
{
if (stream) {
if (stream->device != device) {
THError("invalid stream; expected stream for device %d, but was on %d",
device, stream->device);
}
THCStream_retain(stream);
THAssert(stream);
if (stream->device != device) {
THError("invalid stream; expected stream for device %d, but was on %d",
device, stream->device);
}
THCStream_retain(stream);
THCThreadLocal local = state->currentStreams[device];
THCStream_free((THCStream*)THCThreadLocal_get(local));
THCThreadLocal_set(local, stream);
@@ -464,7 +470,8 @@ static void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device)
{
THCStream* stream = THCState_getStreamOnDevice(state, device);
return stream ? stream->stream : NULL;
THAssert(stream);
return stream->stream;
}

cudaStream_t THCState_getCurrentStream(THCState *state)
@@ -501,9 +508,6 @@ cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)
int THCState_getCurrentStreamIndex(THCState *state)
{
THCStream* stream = THCState_getStream(state);
if (!stream) {
return 0;
}

int device;
THCudaCheck(cudaGetDevice(&device));
@@ -549,13 +553,8 @@ void THCState_setCurrentStreamIndex(THCState *state, int streamIndex)

int device;
for (device = 0; device < state->numDevices; ++device) {
THCStream* stream = NULL;
if (streamIndex != 0) {
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
stream = res->streams[streamIndex];
}

THCState_setStreamOnDevice(state, device, stream);
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
THCState_setStreamOnDevice(state, device, res->streams[streamIndex]);
}
}

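The common thread in the THCGeneral.c hunks above: the NULL-pointer sentinel that used to mean "default stream" is replaced by the real object returned by THCStream_defaultStream(), so current-stream lookups are seeded lazily with that object and the `stream ? stream->stream : NULL` branches disappear. A tiny sketch of that sentinel-to-object swap with illustrative names only (Stream, default_stream, get_current are not THC symbols):

    #include <cstdio>

    struct Stream { int device; void* handle; };   // stand-in for THCStream

    static Stream g_defaults[8];                   // zero-initialized, like default_streams[]

    static Stream* default_stream(int device) { return &g_defaults[device]; }

    // Before: a NULL slot meant "default stream" and every caller had to branch on it.
    // After: the first lookup seeds the slot with the default object, so callers just dereference.
    static Stream* get_current(Stream*& slot, int device)
    {
        if (!slot) {
            slot = default_stream(device);         // cf. THCState_getStreamOnDevice
        }
        return slot;
    }

    int main()
    {
        Stream* slot = NULL;
        std::printf("current stream handle: %p\n", get_current(slot, 0)->handle);  // no NULL check needed
        return 0;
    }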
30 changes: 0 additions & 30 deletions THCStream.c

This file was deleted.

60 changes: 60 additions & 0 deletions THCStream.cpp
@@ -0,0 +1,60 @@
#include "THCStream.h"

#include <mutex>
#include <cuda_runtime_api.h>
#include "THAtomic.h"

#define MAX_DEVICES 256
static THCStream default_streams[MAX_DEVICES];

static void initialize_default_streams()
{
for (int i = 0; i < MAX_DEVICES; i++) {
default_streams[i].device = i;
}
}

THCStream* THCStream_new(int flags)
{
THCStream* self = (THCStream*) malloc(sizeof(THCStream));
self->refcount = 1;
THCudaCheck(cudaGetDevice(&self->device));
THCudaCheck(cudaStreamCreateWithFlags(&self->stream, flags));
return self;
}

THC_API THCStream* THCStream_defaultStream(int device)
{
// default streams aren't refcounted
THAssert(device >= 0 && device < MAX_DEVICES);
static std::once_flag once;
std::call_once(once, &initialize_default_streams);
return &default_streams[device];
}

THCStream* THCStream_newWithPriority(int flags, int priority)
{
THCStream* self = (THCStream*) malloc(sizeof(THCStream));
self->refcount = 1;
THCudaCheck(cudaGetDevice(&self->device));
THCudaCheck(cudaStreamCreateWithPriority(&self->stream, flags, priority));
return self;
}

void THCStream_free(THCStream* self)
{
if (!self || !self->stream) {
return;
}
if (THAtomicDecrementRef(&self->refcount)) {
THCudaCheckWarn(cudaStreamDestroy(self->stream));
free(self);
}
}

void THCStream_retain(THCStream* self)
{
if (self->stream) {
THAtomicIncrementRef(&self->refcount);
}
}
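A note on the file above: user streams from THCStream_new/THCStream_newWithPriority are refcounted and own a real cudaStream_t, while the per-device default streams live in a static array whose stream field stays zero-initialized (the NULL stream), which is exactly what the `self->stream` guards in THCStream_retain and THCStream_free key on. A usage sketch of those semantics, assuming the THC headers from this tree are on the include path (this snippet is not part of the patch):

    #include <cuda_runtime_api.h>
    #include "THCStream.h"

    int main()
    {
        THCStream* user = THCStream_new(cudaStreamNonBlocking);  // refcount 1, owns a real cudaStream_t
        THCStream_retain(user);                                   // refcount 2
        THCStream_free(user);                                     // refcount 1
        THCStream_free(user);                                     // refcount 0: cudaStreamDestroy + free

        THCStream* dflt = THCStream_defaultStream(0);             // static entry, stream field is 0 (NULL stream)
        THCStream_retain(dflt);                                    // no-op because self->stream is 0
        THCStream_free(dflt);                                      // no-op: default streams are never destroyed
        return 0;
    }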
2 changes: 2 additions & 0 deletions THCStream.h
@@ -13,6 +13,8 @@ struct THCStream


THC_API THCStream* THCStream_new(int flags);
THC_API THCStream* THCStream_defaultStream(int device);
THC_API THCStream* THCStream_newWithPriority(int flags, int priority);
THC_API void THCStream_free(THCStream* self);
THC_API void THCStream_retain(THCStream* self);

4 changes: 2 additions & 2 deletions generic/THCTensorCopy.c
@@ -123,7 +123,7 @@ void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor *src)
THTensor_(data)(src),
THTensor_(nElement)(src) * sizeof(real),
cudaMemcpyHostToDevice,
stream == NULL ? NULL : stream->stream));
stream->stream));

THCudaCheck(THCCachingHostAllocator_recordEvent(src->storage->data, stream));

@@ -154,7 +154,7 @@ void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, struct THCTensor *src)
THCTensor_(data)(state, src),
THCTensor_(nElement)(state, src) * sizeof(real),
cudaMemcpyDeviceToHost,
stream == NULL ? NULL : stream->stream));
stream->stream));

THCudaCheck(THCCachingHostAllocator_recordEvent(src->storage->data, stream));

