
Move Stream.query() implementation down to C++ #15737

Closed
wants to merge 10 commits
15 changes: 15 additions & 0 deletions c10/cuda/CUDAStream.h
@@ -5,7 +5,9 @@

#include <cuda_runtime_api.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/Stream.h>

@@ -99,6 +101,19 @@ class C10_CUDA_API CUDAStream {
/// Return the stream ID corresponding to this particular stream.
StreamId id() const { return stream_.id(); }

bool query() const {
DeviceGuard device_guard{stream_.device()};
cudaError_t err = cudaStreamQuery(stream());

if (err == cudaErrorNotReady) {
return false;
} else if (err != cudaSuccess) {
C10_CUDA_CHECK(err);
}

return true;
}

/// Explicit conversion to cudaStream_t.
cudaStream_t stream() const;

8 changes: 8 additions & 0 deletions torch/csrc/cuda/Stream.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/cuda/Module.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

#include <structmember.h>
@@ -43,6 +44,12 @@ static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject
END_HANDLE_TH_ERRORS
}

static PyObject * THCPStream_query(THCPStream *self) {
HANDLE_TH_ERRORS
return PyBool_FromLong(at::cuda::CUDAStream::unpack(self->cdata).query());
END_HANDLE_TH_ERRORS
}

static struct PyMemberDef THCPStream_members[] = {
{(char*)"_cdata", T_ULONGLONG, offsetof(THCPStream, cdata), READONLY, nullptr},
{(char*)"device", T_INT, offsetof(THCPStream, device), READONLY, nullptr},
@@ -51,6 +58,7 @@ static struct PyMemberDef THCPStream_members[] = {
};

static PyMethodDef THCPStream_methods[] = {
{(char*)"query", (PyCFunction)THCPStream_query, METH_NOARGS, nullptr},
{nullptr}
};

2 changes: 1 addition & 1 deletion torch/cuda/__init__.py
@@ -473,7 +473,7 @@ def init_err(self):
class_name = self.__class__.__name__
raise RuntimeError(
"Tried to instantiate dummy base class {}".format(class_name))
return type(storage_name, (object,), {"__init__": init_err})
return type(name, (object,), {"__init__": init_err})


if not hasattr(torch._C, 'CudaDoubleStorageBase'):
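The one-line fix above targets the dummy-class factory used on builds where the real CUDA-backed classes are unavailable: the old code referenced an undefined `storage_name`, so instantiating a placeholder class raised a `NameError` instead of the intended `RuntimeError`. A minimal, self-contained sketch of the pattern (the factory name `_dummy_type` and the example class name below are illustrative, not necessarily the exact names used in `torch/cuda/__init__.py`):

```python
def _dummy_type(name):
    # Factory for placeholder base classes; instantiating one should fail loudly.
    def init_err(self):
        class_name = self.__class__.__name__
        raise RuntimeError(
            "Tried to instantiate dummy base class {}".format(class_name))
    # The fix: build the class from `name` (the old `storage_name` was undefined).
    return type(name, (object,), {"__init__": init_err})


CudaStreamPlaceholder = _dummy_type('CudaStreamPlaceholder')
try:
    CudaStreamPlaceholder()
except RuntimeError as err:
    print(err)  # Tried to instantiate dummy base class CudaStreamPlaceholder
```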
12 changes: 3 additions & 9 deletions torch/cuda/streams.py
@@ -2,6 +2,7 @@
import torch
from . import cudart, check_error, cudaStatus
from ._utils import _get_device_index
from torch._C import _add_docstr


class Stream(torch._C._CudaStreamBase):
@@ -73,15 +74,8 @@ def query(self):
r"""Checks if all the work submitted has been completed.

Returns:
A boolean indicating if all kernels in this stream are completed.
"""
with torch.cuda.device(self.device):
res = cudart().cudaStreamQuery(self)
if res == cudaStatus.ERROR_NOT_READY:
[Review comment on this line] @mrshenli (Contributor, PR author), Jan 4, 2019:
Do we already have an error checking implementation in C++ that I can call? Or should I just return cudaError_t and keep the error checking in Python? Or implement it using AT_CHECK?

return False
check_error(res)
return True
return False
A boolean indicating if all kernels in this stream are completed."""
return super(Stream, self).query()

def synchronize(self):
r"""Wait for all the kernels in this stream to complete.
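With this change, `Stream.query()` delegates to the new C++ binding, which sets the device guard and performs CUDA error checking via `C10_CUDA_CHECK` before returning a plain boolean. A minimal usage sketch of the resulting user-facing behavior (assumes a CUDA-capable build; the matmul workload is arbitrary):

```python
import torch

s = torch.cuda.Stream()
a = torch.randn(4096, 4096, device='cuda')

# Enqueue some work on the side stream without synchronizing.
with torch.cuda.stream(s):
    b = a @ a

print(s.query())  # often False here: the kernel may still be running
s.synchronize()   # block the host until all work on `s` has finished
print(s.query())  # True: the stream's queue is now empty
```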