wrap cudaStreamSynchronize calls (pytorch#61889)
Summary:
This is a first step towards creating a context manager that errors out on synchronizing calls.

Pull Request resolved: pytorch#61889

Reviewed By: albanD

Differential Revision: D29805280

Pulled By: ngimel

fbshipit-source-id: b66400fbe0941b7daa51e6b30abe27b9cccd4e8a
Natalia Gimelshein authored and facebook-github-bot committed Jul 22, 2021
1 parent 3d6aa3a commit 6284d2a
Showing 20 changed files with 89 additions and 107 deletions.
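
The pattern applied across the changed files: raw cudaMemcpyAsync + cudaStreamSynchronize pairs (and the HIP_VERSION-guarded hipMemcpyWithStream variants) collapse into a single call to the new c10::cuda::memcpy_and_sync, and bare cudaStreamSynchronize calls become stream_synchronize, so every synchronizing call funnels through one wrapper. A minimal sketch of a caller after the migration — copy_scalar_to_host is a hypothetical helper, not a function from this commit:

```cpp
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical caller, not part of the diff: copies one scalar from device to
// host through the new wrapper instead of issuing a raw cudaMemcpyAsync
// followed by cudaStreamSynchronize.
template <typename scalar_t>
scalar_t copy_scalar_to_host(scalar_t* device_ptr) {
  scalar_t value;
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  c10::cuda::memcpy_and_sync(
      &value, device_ptr, sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
  return value;
}
```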
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
@@ -86,6 +86,10 @@ jobs:
if: always()
run: |
(! git --no-pager grep -I -no $'#include <cub/' -- ./aten ':(exclude)aten/src/ATen/cuda/cub.cuh' || (echo "The above files have direct cub include; please include ATen/cuda/cub.cuh instead and wrap your cub calls in at::native namespace if necessary"; false))
- name: Ensure no raw cuda api calls
if: always()
run: |
(! git --no-pager grep -I -no $'cudaStreamSynchronize' -- ./aten ./c10 ':(exclude)aten/src/ATen/test' ':(exclude)c10/cuda/CUDAFunctions.h' || (echo "The above files call raw cuda APIs directly; please use at::cuda wrappers instead"; false))
clang-format:
runs-on: ubuntu-18.04
3 changes: 2 additions & 1 deletion Makefile
@@ -87,7 +87,8 @@ quick_checks:
--step 'Ensure no unqualified noqa' \
--step 'Ensure no unqualified type ignore' \
--step 'Ensure no direct cub include' \
--step 'Ensure correct trailing newlines'
--step 'Ensure correct trailing newlines' \
--step 'Ensure no raw cuda api calls'

flake8:
@$(PYTHON) tools/actions_local_runner.py \
18 changes: 1 addition & 17 deletions aten/src/ATen/native/cuda/CUDAScalar.cu
@@ -2,35 +2,19 @@
#include <ATen/NativeFunctions.h>

#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>

#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_version.h>
#endif

namespace at {
namespace native {

Scalar _local_scalar_dense_cuda(const Tensor& self) {
Scalar r;
#if HIP_VERSION >= 301
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
scalar_t value;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(hipMemcpyWithStream(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
r = Scalar(value);
});
#else
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
scalar_t value;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
at::cuda::memcpy_and_sync(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
r = Scalar(value);
});
#endif
return r;
}

7 changes: 1 addition & 6 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -224,12 +224,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
void* ptr = (dst_device == kCPU ? dst : src);
AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(ptr, stream));
} else {
#if HIP_VERSION >= 301
AT_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
#else
AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
}

if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/MiscUtils.h
@@ -75,14 +75,14 @@ struct MagmaStreamSyncGuard {
MagmaStreamSyncGuard() {
auto stream = at::cuda::getCurrentCUDAStream();
if (stream != at::cuda::getDefaultCUDAStream()) {
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
at::cuda::stream_synchronize(stream);
}
}

~MagmaStreamSyncGuard() noexcept(false) {
auto default_stream = at::cuda::getDefaultCUDAStream();
if (at::cuda::getCurrentCUDAStream() != default_stream) {
AT_CUDA_CHECK(cudaStreamSynchronize(default_stream));
at::cuda::stream_synchronize(default_stream);
}
}
};
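
For context, MagmaStreamSyncGuard is an RAII helper around MAGMA calls, which run on the default stream: it synchronizes the current stream on entry (when it differs from the default one) and the default stream on exit, now going through the at::cuda::stream_synchronize wrapper shown above. A hedged usage sketch — call_magma_routine is a stand-in, not an actual MAGMA entry point, and the at::native namespace is assumed:

```cpp
#include <ATen/native/cuda/MiscUtils.h>

// Hypothetical MAGMA call; declared here only so the sketch is self-contained.
void call_magma_routine();

void run_with_magma() {
  // Constructor syncs the current stream if it is not the default stream, so
  // data produced on it is visible to MAGMA's default-stream work.
  at::native::MagmaStreamSyncGuard guard;
  call_magma_routine();
  // Destructor syncs the default stream so results are ready for the caller.
}
```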
4 changes: 1 addition & 3 deletions aten/src/ATen/native/cuda/Nonzero.cu
@@ -61,9 +61,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
auto temp_storage = allocator.allocate(temp_storage_bytes);
cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
int num_nonzeros_h;
C10_CUDA_CHECK(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream));
//need to synchronize to make sure data is available on the host
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
//expected output size is num_nonzeros x ndim
//we are producing output with size {num_nonzeros, ndim} and strides {num_nonzeros, 1} (that is, transposed ndim x num_nonzeros output)
//we are able to directly use passed output with this size and strides, and we can also (per contract)
5 changes: 2 additions & 3 deletions aten/src/ATen/native/cuda/TensorModeKernel.cu
@@ -105,9 +105,8 @@ void calculate_mode(

AT_CUDA_CHECK(cudaMemcpyAsync(
values_data, &mode, sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
AT_CUDA_CHECK(cudaMemcpyAsync(
indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
//memcpy_and_sync will synchronize results
at::cuda::memcpy_and_sync(indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream);
}

template <typename scalar_t>
24 changes: 4 additions & 20 deletions aten/src/THC/generic/THCStorage.cpp
@@ -5,9 +5,6 @@
#include <c10/util/intrusive_ptr.h>
#include <c10/util/typeid.h>

#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_version.h>
#endif

scalar_t* THCStorage_(data)(THCState *state, const THCStorage *self)
{
@@ -26,16 +23,9 @@ void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, scalar
2,
"index out of bounds");
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
#if HIP_VERSION >= 301
THCudaCheck(hipMemcpyWithStream(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t),
cudaMemcpyHostToDevice,
stream));
#else
THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t),
at::cuda::memcpy_and_sync(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t),
cudaMemcpyHostToDevice,
stream));
THCudaCheck(cudaStreamSynchronize(stream));
#endif
stream);
}

scalar_t THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t index)
@@ -46,14 +36,8 @@ scalar_t THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t ind
"index out of bounds");
scalar_t value;
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
#if HIP_VERSION >= 301
THCudaCheck(hipMemcpyWithStream(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t),
cudaMemcpyDeviceToHost, stream));
#else
THCudaCheck(cudaMemcpyAsync(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t),
cudaMemcpyDeviceToHost, stream));
THCudaCheck(cudaStreamSynchronize(stream));
#endif
at::cuda::memcpy_and_sync(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t),
cudaMemcpyDeviceToHost, stream);
return value;
}

34 changes: 6 additions & 28 deletions aten/src/THC/generic/THCStorageCopy.cpp
@@ -2,30 +2,18 @@
#define THC_GENERIC_FILE "THC/generic/THCStorageCopy.cpp"
#else

#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_version.h>
#endif
#include <c10/cuda/CUDAFunctions.h>

void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *src)
{
THArgCheck(self->nbytes() == src->nbytes(), 2, "size does not match");
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
#if HIP_VERSION >= 301
THCudaCheck(hipMemcpyWithStream(
THCStorage_(data)(state, self),
at::cuda::memcpy_and_sync(THCStorage_(data)(state, self),
THStorage_(data)(src),
self->nbytes(),
cudaMemcpyHostToDevice,
stream));
#else
THCudaCheck(cudaMemcpyAsync(
THCStorage_(data)(state, self),
THStorage_(data)(src),
self->nbytes(),
cudaMemcpyHostToDevice,
stream));
THCudaCheck(cudaStreamSynchronize(stream));
#endif
stream);

}

#define TH_CUDA_STORAGE_IMPLEMENT_COPY(TYPEC) \
@@ -61,22 +49,12 @@ void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *s
{
THArgCheck(self->nbytes() == src->nbytes(), 2, "size does not match");
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
#if HIP_VERSION >= 301
THCudaCheck(hipMemcpyWithStream(
at::cuda::memcpy_and_sync(
THStorage_(data)(self),
THCStorage_(data)(state, src),
self->nbytes(),
cudaMemcpyDeviceToHost,
stream));
#else
THCudaCheck(cudaMemcpyAsync(
THStorage_(data)(self),
THCStorage_(data)(state, src),
self->nbytes(),
cudaMemcpyDeviceToHost,
stream));
THCudaCheck(cudaStreamSynchronize(stream));
#endif
stream);
}

#define TH_CUDA_STORAGE_IMPLEMENT_COPYTO(TYPEC) \
3 changes: 2 additions & 1 deletion c10/cuda/CMakeLists.txt
@@ -22,10 +22,10 @@ configure_file(
set(C10_CUDA_SRCS
CUDAStream.cpp
CUDAFunctions.cpp
CUDAMiscFunctions.cpp
CUDACachingAllocator.cpp
impl/CUDAGuardImpl.cpp
impl/CUDATest.cpp
CUDAFunctions.cpp
)
set(C10_CUDA_HEADERS
CUDAException.h
@@ -34,6 +34,7 @@ set(C10_CUDA_HEADERS
CUDAMathCompat.h
CUDAStream.h
CUDAFunctions.h
CUDAMiscFunctions.h
impl/CUDAGuardImpl.h
impl/CUDATest.h
)
3 changes: 2 additions & 1 deletion c10/cuda/CUDAException.h
@@ -1,6 +1,7 @@
#pragma once

#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAMiscFunctions.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>
16 changes: 0 additions & 16 deletions c10/cuda/CUDAFunctions.cpp
@@ -1,6 +1,3 @@
#include <cuda_runtime_api.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Macros.h>

@@ -141,18 +138,5 @@ void device_synchronize() {
C10_CUDA_CHECK(cudaDeviceSynchronize());
}

const char* get_cuda_check_suffix() noexcept {
static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING");
static bool blocking_enabled =
(device_blocking_flag && atoi(device_blocking_flag));
if (blocking_enabled) {
return "";
} else {
return "\nCUDA kernel errors might be asynchronously reported at some"
" other API call,so the stacktrace below might be incorrect."
"\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
}
}

} // namespace cuda
} // namespace c10
25 changes: 24 additions & 1 deletion c10/cuda/CUDAFunctions.h
@@ -8,7 +8,12 @@
// The naming convention used here matches the naming convention of torch.cuda

#include <c10/core/Device.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_version.h>
#endif
#include <cuda_runtime_api.h>

namespace c10 {
namespace cuda {
@@ -30,7 +35,25 @@ C10_CUDA_API void set_device(DeviceIndex device);

C10_CUDA_API void device_synchronize();

C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
// the subsequent functions are defined in the header because for performance
// reasons we want them to be inline
C10_CUDA_API void __inline__ memcpy_and_sync(
void* dst,
void* src,
int64_t nbytes,
cudaMemcpyKind kind,
cudaStream_t stream) {
#if defined(HIP_VERSION) && (HIP_VERSION >= 301)
C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
#else
C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
}

C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}

} // namespace cuda
} // namespace c10
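
With every synchronizing call now funneled through stream_synchronize, the context manager mentioned in the summary becomes straightforward to add later. What follows is purely a speculative sketch of what such a guard could look like; NoSyncGuard, the thread-local flag, and checked_stream_synchronize are not part of this commit:

```cpp
#include <stdexcept>

#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>

// Hypothetical future extension, not in this PR: a thread-local flag that a
// checked variant of stream_synchronize could consult before blocking.
thread_local bool sync_forbidden = false;

struct NoSyncGuard {
  NoSyncGuard() { sync_forbidden = true; }
  ~NoSyncGuard() { sync_forbidden = false; }
};

inline void checked_stream_synchronize(cudaStream_t stream) {
  if (sync_forbidden) {
    throw std::runtime_error(
        "synchronizing CUDA call made while a NoSyncGuard is active");
  }
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
```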
20 changes: 20 additions & 0 deletions c10/cuda/CUDAMiscFunctions.cpp
@@ -0,0 +1,20 @@
#include <c10/cuda/CUDAMiscFunctions.h>
#include <stdlib.h>

namespace c10 {
namespace cuda {

const char* get_cuda_check_suffix() noexcept {
static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING");
static bool blocking_enabled =
(device_blocking_flag && atoi(device_blocking_flag));
if (blocking_enabled) {
return "";
} else {
return "\nCUDA kernel errors might be asynchronously reported at some"
" other API call,so the stacktrace below might be incorrect."
"\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
}
}
} // namespace cuda
} // namespace c10
11 changes: 11 additions & 0 deletions c10/cuda/CUDAMiscFunctions.h
@@ -0,0 +1,11 @@
#pragma once
// this file is to avoid circular dependency between CUDAFunctions.h and
// CUDAExceptions.h

#include <c10/cuda/CUDAMacros.h>

namespace c10 {
namespace cuda {
C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
}
} // namespace c10
5 changes: 2 additions & 3 deletions c10/cuda/CUDAStream.h
@@ -7,8 +7,7 @@

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Exception.h>

/*
@@ -128,7 +127,7 @@ class C10_CUDA_API CUDAStream {

void synchronize() const {
DeviceGuard guard{stream_.device()};
C10_CUDA_CHECK(cudaStreamSynchronize(stream()));
c10::cuda::stream_synchronize(stream());
}

int priority() const {
4 changes: 2 additions & 2 deletions torch/csrc/CudaIPCTypes.cpp
@@ -161,15 +161,15 @@ CudaIPCSentData::CudaIPCSentData(
event_sync_required_ = true;
} else {
auto stream = c10::cuda::getCurrentCUDAStream(device.index());
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
at::cuda::stream_synchronize(stream);
event_ = nullptr;
event_sync_required_ = false;
}
#else
// cuIpcGetEventHandle with HIP is not supported, so we have to sync
// stream instead of passing event
auto stream = c10::cuda::getCurrentCUDAStream(device.index());
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
at::cuda::stream_synchronize(stream);
event_sync_required_ = false;
#endif
}
2 changes: 1 addition & 1 deletion torch/csrc/generic/StorageSharing.cpp
@@ -426,7 +426,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)

// TODO: Instead of cudaStreamSynchronize it is possible to add Stream
// Callback and release counter inside of it (need to check performance impact)
cudaStreamSynchronize(c10::cuda::getCurrentCUDAStream(device));
at::cuda::stream_synchronize(c10::cuda::getCurrentCUDAStream(device));

// We don't want to break existing code, so resource deletion is best
// effort basis. Exception expected if producer process terminated
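
The TODO above points at a non-blocking alternative: enqueue a host callback on the stream and release the reference counter from it instead of synchronizing. A rough sketch under that assumption — release_counter and release_when_stream_done are hypothetical, not code from this diff:

```cpp
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>

// Hypothetical helper standing in for whatever "release counter" logic the
// TODO has in mind; it must not make CUDA calls, per cudaLaunchHostFunc rules.
void release_counter(void* counter_slot);

// Runs release_counter once all work previously enqueued on `stream` has
// finished, without blocking the host the way stream_synchronize does.
inline void release_when_stream_done(cudaStream_t stream, void* counter_slot) {
  C10_CUDA_CHECK(cudaLaunchHostFunc(stream, release_counter, counter_slot));
}
```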