Move Stream.query() implementation down to C++ #15737
Conversation
```diff
@@ -75,13 +75,7 @@ def query(self):
         Returns:
             A boolean indicating if all kernels in this stream are completed.
         """
-        with torch.cuda.device(self.device):
-            res = cudart().cudaStreamQuery(self)
-            if res == cudaStatus.ERROR_NOT_READY:
```
Do we already have an error checking implementation in C++ that I can call? Or should I just return `cudaError_t` and keep the error checking in Python? Or implement it using `AT_CHECK`?
torch/csrc/cuda/Stream.cpp (Outdated)

```diff
@@ -51,6 +60,7 @@ static struct PyMemberDef THCPStream_members[] = {
 };

 static PyMethodDef THCPStream_methods[] = {
+  {(char*)"__query__", (PyCFunction)THCPStream_query, METH_NOARGS, nullptr},
```
Is the name `__query__` appropriate?
Or maybe should I remove `query()` in streams.py and directly expose this `_CudaStreamBase` API as `query`?

Maybe not, otherwise we wouldn't have Python docs.
You can add Python docstrings to C functions using `torch._C._add_docstr`. For example, see `torch/_torch_docs.py`, which adds documentation to the `torch.xxx` functions. I think you can do the same thing from `torch/cuda/streams.py` to add documentation to functions on the base C class.
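As a rough illustration of this pattern, here is a pure-Python analogue with no torch dependency; `add_docstr` below is a hypothetical stand-in for `torch._C._add_docstr` (the real helper is written in C precisely because CPython does not allow assigning `__doc__` on built-in C functions, while plain Python functions allow it):

```python
def add_docstr(obj, docstr):
    """Attach `docstr` to `obj` and return it, mimicking the C helper."""
    obj.__doc__ = docstr
    return obj

def query():
    # Imagine this was implemented in C and shipped without a docstring.
    return True

# Attach documentation after the fact, the way _torch_docs.py does.
add_docstr(query, "Checks if all kernels in this stream are completed.")
```

This keeps the documentation in a Python file near the rest of the module, while the implementation itself stays in C/C++.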
Does it mean it is preferred to completely remove the `query()` implementation in `torch/cuda/streams.py`?
Let me try the following:
- completely remove `query()` from `torch/cuda/streams.py`
- create a new `torch/cuda/_stream_docs.py` file, and use `_add_docstr` to add `query()` docs
- import `torch/cuda/_stream_docs.py` in `torch/cuda/__init__.py`
I think it's slightly better to put the documentation near the code in `torch/cuda/streams.py` than in a separate `torch/cuda/_stream_docs.py`.

The `_torch_docs.py` is a special case because there are so many torch functions and `torch/__init__.py` is rather large.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/cuda/Stream.cpp (Outdated)

```cpp
static PyObject * THCPStream_query(THCPStream *self) {
  HANDLE_TH_ERRORS
  THCPModule_setDevice(self->device);
```
Please use `deviceGuard`; you cannot just change the device (the context manager in the Python implementation restored the original device).
torch/csrc/cuda/Stream.cpp (Outdated)

```cpp
  HANDLE_TH_ERRORS
  THCPModule_setDevice(self->device);
  return PyBool_FromLong(cudaStreamQuery(self->cuda_stream) == cudaSuccess);
```
`THCudaCheck` handles CUDA errors; however, in your case you have to distinguish between `cudaErrorNotReady` (OK, return `False`) and all other CUDA errors (something went wrong, error out), e.g. like here:
pytorch/aten/src/THC/THCCachingAllocator.cpp, lines 492 to 497 in 2d485ff:

```cpp
cudaError_t err = cudaEventQuery(event);
if (err == cudaErrorNotReady) {
  break;
} else if (err != cudaSuccess) {
  AT_CUDA_CHECK(err);
}
```
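The general shape of this pattern, sketched in Python with made-up status codes (these are not real CUDA error values): one specific "not ready" status is a normal `False` result, while every other non-success status must raise.

```python
# Hypothetical status codes standing in for cudaSuccess / cudaErrorNotReady / etc.
SUCCESS, NOT_READY, INVALID_HANDLE = 0, 600, 400

def check(err):
    """Stand-in for AT_CUDA_CHECK: raise on any non-success status."""
    if err != SUCCESS:
        raise RuntimeError(f"cuda-like error {err}")

def query(err):
    """Return True if done, False if still running, raise otherwise."""
    if err == NOT_READY:
        return False   # kernels still running: expected, not an error
    check(err)         # any other failure code raises
    return True
```

Comparing the raw status against `cudaSuccess` alone (as the snippet under review did) would silently turn genuine failures into `False` instead of surfacing them.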
Thank you!
High-level, I think you should use […]. You can get a […]. Natalia's comments are still relevant; they should just be applied to the new `at::cuda::CUDAStream::query()` function.
… docs using torch._C._add_docstr
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
LGTM. Can you verify that the docs work properly before landing? It's something like:

```shell
# in pytorch dir
cd docs
pip install -r requirements.txt  # installs sphinx and stuff
make html                        # or singlehtml for one page
```

You then need to run an HTTP webserver to view the HTML. If you're using an IPv6-only devserver, you can use the Python script `~sgross/bin/serve`.
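For reference, the standard library can serve the built pages directly; a minimal sketch (assuming you run it from the directory containing the generated HTML, and binding port 0 to let the OS pick a free port):

```python
import http.server
import socketserver

handler = http.server.SimpleHTTPRequestHandler
with socketserver.TCPServer(("127.0.0.1", 0), handler) as httpd:
    host, port = httpd.server_address
    print(f"would serve the current directory on http://{host}:{port}/")
    # httpd.serve_forever()  # uncomment to actually serve
```

Equivalently, `python -m http.server` from the shell does the same thing.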
@colesbury Thanks! Yes, I checked the generated html files. Looks correct. Let me try serving it using a webserver.

Update: docs look good!
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@yf225 Hi Will, CI tests on macOS systems encounter the following error (e.g., this one). Do you know if it is due to changes made in this PR?
38052c7 to 7e43975 (compare)
@colesbury Hi Sam, could you please take another look at this PR? I reverted the changes to the point where we define […]
This may be deeper in the rabbit hole than you would like to go, but we don't actually need `torch._C._CudaStreamBase` anymore; we could replace it with a pure Python class which just wraps the `uint64_t` stream identifier, and call into C++ whenever we want to get the actual `cudaStream_t`. This would solve your "doesn't work on non-CUDA doc builds" problem. (Though I suspect that when we do doc builds, it is with CUDA enabled, so this may be a moot point.)
Thanks @ezyang. I don't mind digging deeper if that design is the appropriate way to go. Could you please elaborate more on the following thought?

> Without […]
One plausible approach is: anywhere we previously passed in a `_CudaStreamBase`, we instead just pass the `uint64_t`. The receiving function unpacks it into a `CUDAStream` and then can get the `cudaStream_t` that way. (That's why it's a rabbit hole; all call sites would need to be adjusted.) Maybe @colesbury has other ideas though :>
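The pack-an-integer-handle idea can be sketched without any CUDA: pass a plain integer across the boundary and "unpack" it into the real object on the receiving side. The names below (`register`, `query_by_id`, the `_streams` registry) are all made up for illustration; in the real design the unpacking would be the C++ `CUDAStream` lookup.

```python
_streams = {}  # handle -> underlying stream object (stand-in for C++ state)

def register(stream_obj):
    """Hand out a uint64-sized integer handle for a stream object."""
    sid = id(stream_obj) & 0xFFFFFFFFFFFFFFFF  # fake uint64_t handle
    _streams[sid] = stream_obj
    return sid

def query_by_id(sid):
    """Receiving side: unpack the handle back into a stream and use it."""
    stream = _streams[sid]
    return stream["done"]
```

The cost ezyang mentions is visible even here: every call site must switch from passing the object to passing the handle.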
I think the easiest thing will be to do something like you were doing in a previous commit: have a method definition in Python that calls the C++ method. However, they can both be called `query()` -- you don't need to name one `__query__`.

```python
class Stream(torch._C._CudaStreamBase):
    def query(self):
        r"""Python docstring..."""
        return super(Stream, self).query()
```

We can figure out how to fix […]
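This subclass-and-delegate pattern is runnable without torch if a plain Python class stands in for the C base class (an assumption made purely for illustration):

```python
class _CudaStreamBase:
    """Stand-in for the C base class; its query() has no docstring."""
    def query(self):
        return True  # pretend all kernels in the stream completed

class Stream(_CudaStreamBase):
    def query(self):
        r"""A boolean indicating if all kernels in this stream are completed."""
        return super(Stream, self).query()
```

The Python override carries the docstring (so Sphinx picks it up) while the base method carries the implementation, and both keep the same public name `query()`.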
Got it, thanks!
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
The failed test was due to the following error, and it passed after a rerun.
@mrshenli is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mrshenli is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
See #15682
Pushing up this small PR to check if I am doing the right thing. If correct, more will follow for other Stream APIs. Questions will be added inline.