#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh> // for MAX_DIMS
#include <ATen/cuda/CubUtils.cuh>

#include <limits> // for std::numeric_limits (used in nonzero_out_cuda)

namespace at {
namespace native {

namespace {
// Predicate functor: returns true for nonzero elements. Used below both to count
// the nonzeros and as the flag input for cub::DeviceSelect::Flagged.
template <typename T>
struct NonZeroOp {
  __host__ __device__ __forceinline__ bool operator()(const T& a) const {
    return (a != T(0));
  }
};

// TODO: actually support int64_t index_t
// Plain struct holding the input sizes, so they can be passed to the kernel by value.
template <typename index_t>
struct TensorDims {
  index_t sizes[MAX_DIMS];
};

// Expands the flat indices produced by cub::DeviceSelect into per-dimension coordinates.
// The first n entries of inp hold the flat indices; the coordinate along dimension `dim`
// of the index-th nonzero is written to inp[index + dim * n], i.e. row `dim` of a {ndim, n} buffer.
template <typename index_t>
__global__ void write_indices(int64_t* inp, TensorDims<index_t> dims, int ndim, index_t n) {
  CUDA_KERNEL_LOOP(index, n) { // CUDA_KERNEL_LOOP assumes an int (not int64_t) index
    index_t div = 1;
    int64_t idx_flat = inp[index];
    for (int dim = ndim - 1; dim >= 0; dim--) {
      auto dim_size = dims.sizes[dim];
      inp[index + dim * n] = (idx_flat / div) % dim_size;
      div *= dim_size;
    }
  }
}
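// Worked illustration of the unraveling above (sizes assumed {2, 3}, flat index 4):
//   dim 1: (4 / 1) % 3 = 1, div becomes 3
//   dim 0: (4 / 3) % 2 = 1
// so flat index 4 maps to coordinates (1, 1).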

} // anonymous namespace

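// Overview of the steps below (a descriptive sketch of this implementation):
//   1. Count the nonzero elements with cub::DeviceReduce::Sum over a boolean
//      TransformInputIterator and copy the count back to the host.
//   2. Allocate (or resize) the output as a contiguous {ndim, num_nonzeros} int64 buffer.
//   3. Use cub::DeviceSelect::Flagged with a counting iterator to write the flat index
//      of every nonzero element into the first num_nonzeros entries of that buffer.
//   4. Launch write_indices to expand each flat index into per-dimension coordinates.
//   5. Hand back the transposed buffer, so callers see a {num_nonzeros, ndim} result.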
template <typename scalar_t>
void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
  Tensor self_ = self.contiguous();
  int N = self_.numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Compute the number of nonzero elements. Each cub call is made twice: the first call
  // (with a null temp-storage pointer) only queries the required temp storage size.
  size_t temp_storage_bytes = 0;
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto num_nonzeros = allocator.allocate(sizeof(int));
  cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t*> itr(self_.data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
  cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
  auto temp_storage = allocator.allocate(temp_storage_bytes);
  cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
  int num_nonzeros_h;
  C10_CUDA_CHECK(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream));
  // Need to synchronize to make sure the count is available on the host.
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
  // The expected output size is num_nonzeros x ndim.
  // We produce a result with size {num_nonzeros, ndim} and strides {1, num_nonzeros}
  // (that is, a transposed view of a contiguous ndim x num_nonzeros buffer).
  // We can write directly into a passed output that already has this size and these strides,
  // and we can also (per contract) resize a passed output with incorrect sizes any way we want.
  // However, an output with the correct sizes but different strides has to be copied to
  // from the intermediate we produce.
  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
  at::Tensor out_temp = need_to_copy ?
      at::native::empty_cuda({self.dim(), num_nonzeros_h}, optTypeMetaToScalarType(out.options().dtype_opt()),
                             out.options().layout_opt(), out.options().device_opt(), out.options().pinned_memory_opt()) :
      out.resize_({self.dim(), num_nonzeros_h});
  // Scalars are expected to produce output of size (1, 0), so we can't write to it.
  if (self.dim() > 0) {
    cub::CountingInputIterator<int64_t> counting_itr(0);
    temp_storage_bytes = 0;
    cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,
                               out_temp.data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
    temp_storage = allocator.allocate(temp_storage_bytes);
    cub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr,
                               out_temp.data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
    if (num_nonzeros_h > 0 && self.dim() > 1) {
      TensorDims<int> dims;
      for (int i = 0; i < self.dim(); i++) {
        dims.sizes[i] = self.sizes()[i];
      }
      const int nthreads = 256;
      const int nblocks = (num_nonzeros_h + nthreads - 1) / nthreads;
      write_indices<<<nblocks, nthreads, 0, stream>>>(out_temp.data_ptr<int64_t>(),
                                                      dims, self.dim(), num_nonzeros_h);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
  if (need_to_copy) {
    out.copy_(out_temp.t());
  } else {
    // Transpose out_temp so out has the expected {num_nonzeros, ndim} size.
    Tensor out_ = out_temp.t();
    out.set_(out_);
  }
}

Tensor& nonzero_out_cuda(Tensor& out, const Tensor& self) {
  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(),
              "nonzero is not supported for tensors with more than INT_MAX elements, file a support request");
  TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype());
  TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ",
              out.device(), " and self on ", self.device());
  TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions");
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
                             self.scalar_type(), "nonzero_cuda",
                             [&] { nonzero_cuda_out_impl<scalar_t>(self, out); });
  return out;
}

Tensor nonzero_cuda(const Tensor& self) {
  Tensor out = at::native::empty_cuda({0}, kLong, self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
  return nonzero_out_cuda(out, self);
}
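
// Illustrative example of the contract (values assumed purely for illustration):
//   input  (2x3 CUDA tensor):  [[0, 1, 0],
//                               [2, 0, 3]]
//   output (shape {3, 2}):     [[0, 1],
//                               [1, 0],
//                               [1, 2]]
// Row i of the output holds the coordinates of the i-th nonzero element, in row-major order.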
} // namespace native
} // namespace at