// Repeat.cpp
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/native/Repeat.h>
#include <c10/util/irange.h>

// Fills result_ptr so that index i appears repeat_ptr[i] times, using the
// inclusive cumulative sum in cumsum_ptr to locate the end of each run.
template <typename index_t>
static void compute_cpu(
    index_t* repeat_ptr,
    int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
  TORCH_CHECK(
      (result_size == cumsum_ptr[size - 1]),
      "allocated size does not match required size");
  at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
    for (const auto i : c10::irange(i_begin, i_end)) {
      int64_t end = cumsum_ptr[i];
      index_t size = repeat_ptr[i]; // note: shadows the outer `size`
      TORCH_CHECK((size >= 0), "repeats can not be negative");
      int64_t start = end - size;
      for (const auto j : c10::irange(start, end)) {
        result_ptr[j] = i;
      }
    }
  });
}
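
// A worked sketch of the expansion above (hypothetical values): with
// repeat_ptr = {2, 0, 3} and cumsum_ptr = {2, 2, 5}, result_size must be 5,
// and each index i is written into the slots
// [cumsum_ptr[i] - repeat_ptr[i], cumsum_ptr[i]), so
// result_ptr ends up as {0, 0, 2, 2, 2}.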
namespace at {
namespace native {

Tensor repeat_interleave_cpu(
    const Tensor& repeat,
    c10::optional<int64_t> output_size) {
  Tensor output;
  AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cpu", [&]() {
    output = repeat_interleave_common<index_t, compute_cpu<index_t>>(
        repeat, output_size);
  });

  return output;
}
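
// Usage sketch (hypothetical values): this kernel backs the single-tensor
// overload, so at::repeat_interleave(at::tensor({1, 2})) would produce
// {0, 1, 1} -- the index vector consumed by index_select in the
// overload below.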
Tensor repeat_interleave(
    const Tensor& self,
    const Tensor& repeats,
    c10::optional<int64_t> dim,
    c10::optional<int64_t> output_size) {
  Tensor input = self;

  // Store conj and neg bits
  const auto conj = input.is_conj();
  if (conj) {
    input = input.conj();
  }
  const auto neg = input.is_neg();
  if (neg) {
    input = input._neg_view();
  }

  if (!dim) {
    input = input.flatten();
    dim = 0;
  }

  Tensor repeats_ = repeats;
  if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
    repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
  } else if (repeats.dim() == 1) {
    TORCH_CHECK(
        repeats.size(0) == input.size(dim.value()),
        "repeats must have the same size as input along dim");
  } else {
    AT_ERROR("repeats must be 0-dim or 1-dim tensor");
  }

  auto ret = input.index_select(
      dim.value(), at::repeat_interleave(repeats_, output_size));
  // Restore conj and neg bits
  if (conj) {
    ret = ret.conj();
  }
  if (neg) {
    ret = ret._neg_view();
  }
  return ret;
}
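
// Behavior sketch (hypothetical values): with self = {{1, 2}, {3, 4}},
// repeats = {1, 2}, and dim = 0, the result is {{1, 2}, {3, 4}, {3, 4}};
// with dim omitted, self is flattened first.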
Tensor repeat_interleave(
    const Tensor& self,
    int64_t repeats,
    c10::optional<int64_t> dim,
    c10::optional<int64_t> output_size) {
  at::Tensor repeats_ =
      at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
  return at::native::repeat_interleave(self, repeats_, dim, output_size);
}
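
// Behavior sketch (hypothetical values): at::repeat_interleave(t, 2) with
// t = {1, 2} broadcasts the scalar repeat count, flattens t (no dim given),
// and repeats every element twice, giving {1, 1, 2, 2}.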
} // namespace native
} // namespace at