Revert D14605905: [pytorch][PR] Add return_counts to torch.unique
Differential Revision: D14605905

Original commit changeset: 555f5a12a8e2

fbshipit-source-id: c7874f5987893e956c022180a37763d88bba38db
soumith authored and facebook-github-bot committed Mar 27, 2019
1 parent bdd098c commit 66628f7
Showing 9 changed files with 68 additions and 214 deletions.
49 changes: 17 additions & 32 deletions aten/src/ATen/native/Unique.cpp
@@ -14,11 +14,10 @@ namespace native{
 namespace {

 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor> _unique_cpu_template(
     const Tensor& self,
     const bool sorted,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
   std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
@@ -34,8 +33,7 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
   }

   Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-  Tensor counts = at::empty({0}, self.options().dtype(kLong));
-  if (return_inverse || return_counts) {
+  if (return_inverse) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
     std::unordered_map<scalar_t, int64_t> inverse_map;
@@ -46,29 +44,21 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     for (int i = 0; i < input.numel(); ++i) {
       inverse_indices_data[i] = inverse_map[input_data[i]];
     }
-    if (return_counts) {
-      counts.resize_(output.sizes());
-      counts.fill_(0);
-      for (int i = 0; i < input.numel(); ++i) {
-        counts[inverse_map[input_data[i]]] += 1;
-      }
-    }
   }
-  return std::make_tuple(output, inverse_indices, counts);
+  return std::make_tuple(output, inverse_indices);
 }

 template<class ForwardIt>
 ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
-    std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
+    std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
   if (first == last) {
     return last;
   }
   // save to calculate distance to iterators
   ForwardIt begin = first;

-  // set first inverse index and count
+  // set first inverse index
   inverse_indices_vec[indices[0]] = 0;
-  counts[0] += 1;

   ForwardIt result = first;
   while (++first != last) {
@@ -78,18 +68,16 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
     int64_t idx_result = std::distance(begin, result);
     int64_t idx_first = std::distance(begin, first);
     inverse_indices_vec[indices[idx_first]] = idx_result;
-    counts[idx_result] += 1;
   }

   return ++result;
 }

 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
+std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
     const Tensor& self,
     const int64_t dim,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {
   // reshape tensor as [dim, -1]
   Tensor input_flat = self.transpose(dim, 0);
   auto orig_sizes = input_flat.sizes().vec();
@@ -121,12 +109,10 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   }

   Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
-  Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong));
   std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
   auto last = _unique_dim_cpu_impl(
-    input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts);
+    input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
   input_unbind.erase(last, input_unbind.end());
-  counts = at::narrow(counts, 0, 0, input_unbind.size());

   // reshape back
   auto output = at::stack(input_unbind, 0);
@@ -135,23 +121,22 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   output = output.view(new_sizes);
   output = output.transpose(0, dim);

-  return std::make_tuple(output, inverse_indices, counts);
+  return std::make_tuple(output, inverse_indices);
 }
 } // namespace


-std::tuple<Tensor, Tensor, Tensor>
-_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
-    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+std::tuple<Tensor, Tensor>
+_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
+    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
   });
 }

-std::tuple<Tensor, Tensor, Tensor>
-_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
+std::tuple<Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
-    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
   });
 }

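For readers skimming the diff: the CPU kernel that this revert keeps builds the set of unique values with a hash set and, when `return_inverse` is requested, maps every input element to its slot in the output through a hash map. A minimal standalone sketch of that idea (plain `std::vector` standing in for `Tensor`, with a hypothetical helper name) could look like this:

```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Sketch of the hash-based approach in _unique_cpu_template: gather distinct
// values into an unordered_set, optionally sort them, then build inverse
// indices by looking up each input element's position in the output.
std::pair<std::vector<int64_t>, std::vector<int64_t>> unique_with_inverse(
    const std::vector<int64_t>& input, bool sorted, bool return_inverse) {
  std::unordered_set<int64_t> set(input.begin(), input.end());
  std::vector<int64_t> output(set.begin(), set.end());
  if (sorted) {
    std::sort(output.begin(), output.end());
  }

  std::vector<int64_t> inverse_indices;
  if (return_inverse) {
    std::unordered_map<int64_t, int64_t> inverse_map;
    for (int64_t i = 0; i < static_cast<int64_t>(output.size()); ++i) {
      inverse_map[output[i]] = i;
    }
    inverse_indices.resize(input.size());
    for (size_t i = 0; i < input.size(); ++i) {
      inverse_indices[i] = inverse_map[input[i]];
    }
  }
  return {output, inverse_indices};
}
```

The real kernel runs the same loops over the contiguous input's raw data pointer and writes the inverse indices into a kLong tensor shaped like the input.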
65 changes: 23 additions & 42 deletions aten/src/ATen/native/cuda/Unique.cu
@@ -16,10 +16,9 @@ namespace native{
 namespace {

 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cuda_template(
+std::tuple<Tensor, Tensor> _unique_cuda_template(
     const Tensor& self,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -29,7 +28,7 @@ template <typename scalar_t>
   int64_t num_inp = input.numel();
   const scalar_t* input_data = input.data<scalar_t>();

-  //sort
+  //sort & unique
   Tensor output = input.clone();
   output = output.view(-1);
   scalar_t* output_data = output.data<scalar_t>();
@@ -48,36 +47,21 @@ template <typename scalar_t>
     thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
     inv_loc[0] = 0;
     thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
-    thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
+    thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
     inverse_indices.resize_(input.sizes());
   }
-
-  // unique
-  Tensor counts = at::empty({0}, self.options().dtype(kLong));
-  if (!return_counts) {
-    int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
-    output.resize_(num_out);
-  } else {
-    Tensor sorted_indices = at::arange(0, num_inp + 1, self.type().toScalarType(kLong));
-    int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
-    int64_t num_out = thrust::unique_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr).first - output_data;
-    sorted_indices[num_out] = num_inp;
-    output.resize_(num_out);
-    counts.resize_(num_out);
-    int64_t* counts_ptr = counts.data<int64_t>();
-    thrust::adjacent_difference(policy, sorted_indices_ptr + 1, sorted_indices_ptr + num_out + 1, counts_ptr);
-  }
+  int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
+  output.resize_(num_out);

   THCudaCheck(cudaGetLastError());
-  return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
+  return std::tuple<Tensor, Tensor>(output, inverse_indices);
 }

 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_dim_cuda_template(
+std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
     const Tensor& self,
     const int64_t dim,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -89,7 +73,7 @@ template <typename scalar_t>

   scalar_t* input_flat_ptr = input_flat.data<scalar_t>();

-  Tensor indices = at::arange(0, input_flat.size(0), self.options().dtype(kLong));
+  Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
   int64_t* indices_ptr = indices.data<int64_t>();
   int64_t numel = input_flat.size(1);

@@ -112,7 +96,7 @@ template <typename scalar_t>

   // get unique tensors
   scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
-  Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.options().dtype(kLong));
+  Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
   int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
   auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
     [=] __device__ (int64_t a, int64_t b) -> bool {
@@ -134,13 +118,12 @@ template <typename scalar_t>
   output = output.view(new_sizes);
   output = output.transpose(0, dim);

-  // calculate inverse indices and counts
-  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-  Tensor counts = at::zeros(output.size(dim), self.options().dtype(kLong));
-  if (return_inverse || return_counts) {
+  // calculate inverse indices
+  Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+  if (return_inverse) {
     int64_t size = self.size(dim);
     inverse_indices.resize_(size);
-    Tensor mask = at::empty(input_sorted.size(0), self.options().dtype(kLong));
+    Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
     mask[0] = 1;
     for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
       if (!at::equal(input_sorted[i], input_sorted[i+1])) {
@@ -153,29 +136,27 @@ template <typename scalar_t>
     Tensor imask = at::cumsum(mask, 0) - 1;
     for (int i = 0; i < indices.size(0); ++i) {
       inverse_indices[indices[i]] = imask[i];
-      counts[inverse_indices[indices[i]]] += 1;
     }
   }

   THCudaCheck(cudaGetLastError());
-  return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
+  return std::tuple<Tensor, Tensor>(output, inverse_indices);
 }
 } // namespace

-
-std::tuple<Tensor, Tensor, Tensor>
-_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+std::tuple<Tensor, Tensor>
+_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
-    return _unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
+    return _unique_cuda_template<scalar_t>(self, return_inverse);
   });
 }

-std::tuple<Tensor, Tensor, Tensor>
-_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
+std::tuple<Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
-    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
   });
 }

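The CUDA path cannot hash (the dispatch comment above notes the lack of a hashtable in thrust), so it sorts the flattened input, marks with adjacent_difference the positions where a new value starts, turns those marks into per-element group ids with inclusive_scan, scatters the group ids back to the original positions to form the inverse indices, and finally compacts the sorted buffer with thrust::unique. A host-side sketch of the same pipeline, with std:: algorithms standing in for the Thrust calls and a hypothetical helper name, might look like:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Host-side illustration of the sort / adjacent_difference / inclusive_scan /
// scatter pipeline that _unique_cuda_template runs with Thrust on the GPU.
std::pair<std::vector<int64_t>, std::vector<int64_t>> unique_sort_scan_scatter(
    const std::vector<int64_t>& input) {
  const int64_t n = static_cast<int64_t>(input.size());

  // sort the values, dragging the original positions along (like sort_by_key)
  std::vector<int64_t> sorted_indices(n);
  std::iota(sorted_indices.begin(), sorted_indices.end(), 0);
  std::sort(sorted_indices.begin(), sorted_indices.end(),
            [&](int64_t a, int64_t b) { return input[a] < input[b]; });
  std::vector<int64_t> sorted(n);
  for (int64_t i = 0; i < n; ++i) sorted[i] = input[sorted_indices[i]];

  // adjacent_difference: 1 wherever a new value starts, 0 otherwise
  std::vector<int64_t> inv_loc(n, 0);
  for (int64_t i = 1; i < n; ++i) inv_loc[i] = sorted[i] != sorted[i - 1] ? 1 : 0;

  // inclusive_scan: running group id of each sorted element
  std::partial_sum(inv_loc.begin(), inv_loc.end(), inv_loc.begin());

  // scatter the group ids back to the elements' original positions
  std::vector<int64_t> inverse_indices(n);
  for (int64_t i = 0; i < n; ++i) inverse_indices[sorted_indices[i]] = inv_loc[i];

  // unique: compact the sorted values into the output
  std::vector<int64_t> output = sorted;
  output.erase(std::unique(output.begin(), output.end()), output.end());
  return {output, inverse_indices};
}
```

The `dim` variant in the hunk above follows the same shape: adjacent-slice comparison builds the mask, at::cumsum plays the role of the scan, and the loop over indices scatters the group ids back.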
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -2339,14 +2339,14 @@
   matches_jit_signature: True
   variants: method

-- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
     CPU: _unique_cpu
     CUDA: _unique_cuda

-- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
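After this revert, the two schemas above once again describe two-tensor returns, and the generated at::_unique / at::_unique_dim entry points are called accordingly. A hedged usage sketch (the factory calls, values, and the example function name are illustrative assumptions, not part of the commit):

```cpp
#include <ATen/ATen.h>
#include <tuple>

void unique_example() {
  // _unique: deduplicate all elements of a tensor (illustrative input).
  at::Tensor t = at::randint(/*low=*/0, /*high=*/3, {5}, at::kLong);
  at::Tensor values, inverse;
  std::tie(values, inverse) =
      at::_unique(t, /*sorted=*/true, /*return_inverse=*/true);
  // values holds the sorted distinct elements; inverse has t's shape and maps
  // every element of t to its position in values.

  // _unique_dim: deduplicate slices along a dimension instead of scalars.
  at::Tensor m = at::stack({t, t});  // two identical rows
  at::Tensor rows, row_inverse;
  std::tie(rows, row_inverse) =
      at::_unique_dim(m, /*dim=*/0, /*sorted=*/true, /*return_inverse=*/true);
}
```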
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -301,7 +301,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
     indices = nz.clone();
   } else {
     Tensor i = nz.narrow(0, 0, sparse_dim);
-    std::tie(indices, std::ignore, std::ignore) = _unique_dim(i, 1);
+    std::tie(indices, std::ignore) = _unique_dim(i, 1);
     indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
   }

