
Move Stream.query() implementation down to C++ #15737

Closed
wants to merge 10 commits

Conversation

mrshenli
Contributor

@mrshenli mrshenli commented Jan 4, 2019

See #15682

Pushing up this small PR to check if I am doing the right thing. If correct, more will follow for other Stream APIs. Questions will be added inline.

@mrshenli mrshenli requested a review from colesbury January 4, 2019 17:29
@@ -75,13 +75,7 @@ def query(self):
        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        with torch.cuda.device(self.device):
            res = cudart().cudaStreamQuery(self)
            if res == cudaStatus.ERROR_NOT_READY:
Contributor Author

@mrshenli mrshenli Jan 4, 2019

Do we already have an error checking implementation in C++ that I can call? Or should I just return cudaError_t and keep the error checking in Python? Or implement it using AT_CHECK?

@@ -51,6 +60,7 @@ static struct PyMemberDef THCPStream_members[] = {
};

static PyMethodDef THCPStream_methods[] = {
  {(char*)"__query__", (PyCFunction)THCPStream_query, METH_NOARGS, nullptr},
Contributor Author

@mrshenli mrshenli Jan 4, 2019

Is the name __query__ appropriate?

Contributor Author

@mrshenli mrshenli Jan 4, 2019

Or maybe I should remove query() from streams.py and directly expose this _CudaStreamBase API as query? Maybe not; otherwise we wouldn't have Python docs.

Member

You can add Python docstrings to C functions using torch._C._add_docstr. For example, see torch/_torch_docs.py which adds documentation to the torch.xxx functions.

I think you can do the same thing from torch/cuda/streams.py to add documentation to functions on the base C class.
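
For reference, a minimal sketch of that pattern (the docstring text below is illustrative, and the sketch assumes _CudaStreamBase actually exposes a query attribute, which, as it turns out later in this thread, is not the case on non-CUDA builds):

from torch._C import _add_docstr, _CudaStreamBase

_add_docstr(_CudaStreamBase.query, r"""
query() -> bool

Checks if all the work submitted on this stream has completed.
""")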

Contributor Author

Does that mean it is preferred to completely remove the query() implementation from torch/cuda/streams.py?

Contributor Author

Let me try the following:

  1. Completely remove query() from torch/cuda/streams.py.
  2. Create a new torch/cuda/_stream_docs.py file, and use _add_docstr to add the query() docs.
  3. Import torch/cuda/_stream_docs.py in torch/cuda/__init__.py.

Member

I think it's slightly better to put the documentation near the code in torch/cuda/streams.py than in a separate torch/cuda/_stream_docs.py.

The _torch_docs.py file is a special case because there are so many torch functions and torch/__init__.py is rather large.

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

static PyObject * THCPStream_query(THCPStream *self) {
  HANDLE_TH_ERRORS

  THCPModule_setDevice(self->device);
Collaborator

Please use a device guard; you cannot just change the device (the context manager in the Python implementation restored the original device).
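
For illustration, a minimal sketch of the guard-based version (assuming at::cuda::CUDAGuard from ATen/cuda/CUDAGuard.h; the error handling question is addressed in the next comment):

#include <ATen/cuda/CUDAGuard.h>

static PyObject * THCPStream_query(THCPStream *self) {
  HANDLE_TH_ERRORS
  // RAII guard: switches to this stream's device and restores the
  // previous device when the guard goes out of scope.
  at::cuda::CUDAGuard device_guard(self->device);
  return PyBool_FromLong(cudaStreamQuery(self->cuda_stream) == cudaSuccess);
  END_HANDLE_TH_ERRORS
}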

  HANDLE_TH_ERRORS

  THCPModule_setDevice(self->device);
  return PyBool_FromLong(cudaStreamQuery(self->cuda_stream) == cudaSuccess);
Collaborator

THCudaCheck handles CUDA errors; however, in your case you have to distinguish between cudaErrorNotReady (OK, return False) and all other CUDA errors (something went wrong, error out), e.g. like here:

cudaError_t err = cudaEventQuery(event);
if (err == cudaErrorNotReady) {
  // Not finished yet; this is an expected, non-error state.
  break;
} else if (err != cudaSuccess) {
  // Any other non-success code is a real error; raise it.
  AT_CUDA_CHECK(err);
}

Contributor Author

Thank you!

@colesbury
Member

High-level: I think you should use at::cuda::CUDAStream and put the body of query() in that class.

You can get an at::cuda::CUDAStream from self->cdata via at::cuda::CUDAStream::unpack(cdata).

Natalia's comments are still relevant, they should just be applied to the new at::cuda::CUDAStream::query() function.
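
Combining that suggestion with the error handling above, a hedged sketch of what such a member function might look like (the exact names here, e.g. device_index() and AT_CUDA_CHECK, are assumptions, not the final API):

bool CUDAStream::query() const {
  // Switch to this stream's device for the duration of the call.
  at::cuda::CUDAGuard device_guard(device_index());
  cudaError_t err = cudaStreamQuery(stream());
  if (err == cudaSuccess) {
    return true;
  } else if (err != cudaErrorNotReady) {
    // Any failure other than "not ready" is a real error.
    AT_CUDA_CHECK(err);
  }
  return false;
}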

Member

@colesbury colesbury left a comment

LGTM. Can you verify that the docs work properly before landing?

It's something like:

# in pytorch dir
cd docs
pip install -r requirements.txt  # installs sphinx and stuff
make html  # or singlehtml for one page

You then need to run an HTTP web server to view the HTML. If you're using an IPv6-only devserver, you can use the Python script ~sgross/bin/serve.
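
Any static file server works; for example, assuming the Sphinx output landed in docs/build/html:

cd build/html
python -m http.server 8000  # Python 3; on Python 2 use: python -m SimpleHTTPServer 8000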

@mrshenli
Contributor Author

mrshenli commented Jan 4, 2019

@colesbury Thanks!

Yes, I checked the generated HTML files. They look correct. Let me try serving them with a web server.

Update: the docs look good!

@mrshenli
Contributor Author

mrshenli commented Jan 6, 2019

@yf225 Hi Will, the CI tests on macOS systems encounter the following error (e.g., this one). Do you know if it is due to changes made in this PR?

Error: Cannot install moreutils because conflicting formulae are installed.
  parallel: because Both install a `parallel` executable.

Please `brew unlink parallel` before continuing.

@mrshenli mrshenli force-pushed the streamapi branch 2 times, most recently from 38052c7 to 7e43975 on January 6, 2019 at 20:13
@mrshenli
Contributor Author

mrshenli commented Jan 7, 2019

@colesbury Hi Sam, could you please take another look at this PR? I reverted the changes to the point where we define query in torch/cuda/streams.py and call _query(). Otherwise, it would hit the following error on platforms without CUDA. This is because torch._C._CudaStreamBase is assigned a dummy type during initialization, and hence there is no query attribute. (BTW, the naming mismatch in the log is because _dummy_type should use name instead of storage_name.)

I thought about only running _add_docstr(_CudaStreamBase.query, ...) when query is defined, but then the query docs could not be generated on non-CUDA systems. I also tried adding a query attribute to the dummy type, which does not work either, because THPModule_addDocStr checks the type of the attribute. Is this an appropriate solution, or do you have any suggestions?

Jan 06 17:50:16 __________________ ERROR collecting test/onnx/test_models.py ___________________
Jan 06 17:50:16 workspace/test/onnx/test_models.py:1: in <module>
Jan 06 17:50:16     from torchvision.models.alexnet import alexnet
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torchvision/__init__.py:1: in <module>
Jan 06 17:50:16     from torchvision import models
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torchvision/models/__init__.py:1: in <module>
Jan 06 17:50:16     from .alexnet import *
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torchvision/models/alexnet.py:1: in <module>
Jan 06 17:50:16     import torch.nn as nn
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torch/__init__.py:250: in <module>
Jan 06 17:50:16     _C._initExtension(manager_path())
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torch/cuda/__init__.py:554: in <module>
Jan 06 17:50:16     from .streams import Stream, Event
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torch/cuda/streams.py:8: in <module>
Jan 06 17:50:16     class Stream(_CudaStreamBase):
Jan 06 17:50:16 .local/lib/python2.7/site-packages/torch/cuda/streams.py:73: in Stream
Jan 06 17:50:16     query = _add_docstr(_CudaStreamBase.query, r"""
Jan 06 17:50:16 E   AttributeError: type object 'CudaHalfStorageBase' has no attribute 'query'

@ezyang
Contributor

ezyang commented Jan 7, 2019

This may be deeper in the rabbit hole than you would like to go, but we don't actually need torch._C._CudaStreamBase anymore; we could replace it with a pure Python class that just wraps the uint64_t stream identifier and calls into C++ whenever we want to get the actual cudaStream_t. This would solve your "doesn't work on non-CUDA doc builds" problem. (Though I suspect that when we do doc builds, it is with CUDA enabled, so this may be a moot point.)

@mrshenli
Contributor Author

mrshenli commented Jan 7, 2019

Thanks @ezyang. I don't mind digging deeper if that design is the appropriate way to go. Could you please elaborate on the following thought?

replace it with a pure Python class which just wrapped the uint64_t stream identifier, and call into C++ whenever we want to get the actual cudaStream_t

Without _CudaStreamBase, where should we put the implementations of those C++ functions? Directly call CUDAStream?

@ezyang
Contributor

ezyang commented Jan 7, 2019

One plausible approach: anywhere we previously passed in a _CudaStreamBase, we instead just pass the uint64_t. The receiving function unpacks it into a CUDAStream and can then get the cudaStream_t that way. (That's why it's a rabbit hole; all call sites would need to be adjusted.) Maybe @colesbury has other ideas though :>
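
A hedged sketch of the receiving side of that proposal (stream_is_done is a hypothetical function; the unpack/stream helpers are assumed from the earlier at::cuda::CUDAStream::unpack(cdata) suggestion):

// Hypothetical receiver: callers pass the packed uint64_t instead of a
// _CudaStreamBase object; we reconstruct the CUDAStream to reach cudaStream_t.
bool stream_is_done(uint64_t packed) {
  at::cuda::CUDAStream stream = at::cuda::CUDAStream::unpack(packed);
  return cudaStreamQuery(stream.stream()) == cudaSuccess;
}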

@colesbury
Member

I think the easiest thing will be to do something like you were doing in a previous commit: have a method definition in Python that calls the C++ method. However, they can both be called "query()" -- you don't need to name one __query__. Roughly (in Python):

class Stream(torch._C._CudaStreamBase):
    def query(self):
        r"""Python docstring..."""
        return super(Stream, self).query()

We can figure out how to fix _add_docstr or the dummy class later.

@mrshenli
Contributor Author

mrshenli commented Jan 7, 2019

Got it, thanks!

@mrshenli
Contributor Author

mrshenli commented Jan 7, 2019

The failed test was due to the following error and passed after a rerun.

21:05:50 ../../../../usr/local/caffe2/lib/python2.7/site-packages/caffe2/python/hypothesis_test.py::TestOperators::test_dag_net_forking ./.jenkins/caffe2/test.sh: line 96:  4908 Aborted                 (core dumped) "$PYTHON" -m pytest -x -v --disable-warnings --junit-xml="$pytest_reports_dir/result.xml" --ignore "$CAFFE2_PYPATH/python/test/executor_test.py" --ignore "$CAFFE2_PYPATH/python/operator_test/matmul_op_test.py" --ignore "$CAFFE2_PYPATH/python/operator_test/pack_ops_test.py" --ignore "$CAFFE2_PYPATH/python/mkl/mkl_sbn_speed_test.py" ${rocm_ignore_test[@]} "$CAFFE2_PYPATH/python" "${EXTRA_TESTS[@]}"

@facebook-github-bot facebook-github-bot left a comment

@mrshenli is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
