forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTriangularOps.cpp
182 lines (157 loc) · 5.25 KB
/
TriangularOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#include <c10/util/irange.h>
namespace at {
namespace meta {
// Meta functions for the structured tril/triu kernels: the output takes the
// sizes and TensorOptions of `self`. The diagonal offset `k` does not affect
// output metadata.
TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) {
  const auto out_options = self.options();
  set_output(self.sizes(), out_options);
}
TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) {
  const auto out_options = self.options();
  set_output(self.sizes(), out_options);
}
} // namespace meta
namespace native {
namespace {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Applies the triu/tril mask to one n x m matrix, parallelized over rows.
//
// Element (i, j) is kept when j >= i + k (upper) or j <= i + k (lower).
// Every other element of `result` is written as zero. When `inplace` is false,
// the kept elements are copied from `self`. When it is true, only the zeroing
// runs, so kept elements of `result` are left untouched.
// Elements are addressed as base[i * row_stride + j * col_stride], using the
// respective strides for `result` and `self`.
template <typename scalar_t>
void apply_triu_tril_single(
    scalar_t* result,
    scalar_t* self,
    bool inplace,
    int64_t k,
    int64_t n,
    int64_t m,
    int64_t res_row_stride,
    int64_t res_col_stride,
    int64_t self_row_stride,
    int64_t self_col_stride,
    bool upper) {
  constexpr int64_t zero = 0;

  // Zeroes result columns [first, last) of `row`; empty when last <= first.
  auto clear_cols = [&](int64_t row, int64_t first, int64_t last) {
    for (int64_t col = first; col < last; ++col) {
      result[row * res_row_stride + col * res_col_stride] = 0;
    }
  };
  // Copies self columns [first, last) of `row` into result.
  auto copy_cols = [&](int64_t row, int64_t first, int64_t last) {
    for (int64_t col = first; col < last; ++col) {
      result[row * res_row_stride + col * res_col_stride] =
          self[row * self_row_stride + col * self_col_stride];
    }
  };

  parallel_for(0, n, 0, [&](int64_t row_begin, int64_t row_end) {
    for (int64_t row = row_begin; row < row_end; ++row) {
      if (upper) {
        const int64_t first_kept = row + k;
        clear_cols(row, zero, std::min(m, first_kept));
        if (!inplace) {
          copy_cols(row, std::max(zero, first_kept), m);
        }
      } else {
        const int64_t first_cleared = row + k + 1;
        clear_cols(row, std::max(zero, first_cleared), m);
        if (!inplace) {
          copy_cols(row, zero, std::min(m, first_cleared));
        }
      }
    }
  });
}
template <typename scalar_t>
void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
auto n = self.size(-2);
auto m = self.size(-1);
auto self_data = self.data_ptr<scalar_t>();
auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
auto batchsize = batchCountTrilTriu(result);
auto self_row_stride = self.stride(-2);
auto self_col_stride = self.stride(-1);
auto result_data = result.data_ptr<scalar_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t result_stride, result_row_stride, result_col_stride;
if (result_data != self_data) {
result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
result_row_stride = result.stride(-2);
result_col_stride = result.stride(-1);
} else {
result_stride = self_stride;
result_row_stride = self_row_stride;
result_col_stride = self_col_stride;
}
parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
for (const auto b : c10::irange(start, end)) {
scalar_t* self_batch = &self_data[b * self_stride];
scalar_t* result_batch = &result_data[b * result_stride];
apply_triu_tril_single<scalar_t>(
result_batch,
self_batch,
inplace,
k,
n,
m,
result_row_stride,
result_col_stride,
self_row_stride,
self_col_stride,
upper);
}
});
}
// Compile-time tags that select the triangle kept by compute_triu_tril.
// `op_name` is the name passed to the AT_DISPATCH_* macro. `upper` chooses the
// masking rule in apply_triu_tril_single.
struct UpperTriangle {
  static constexpr const char* op_name{"triu"};
  static constexpr bool upper{true};
};
struct LowerTriangle {
  static constexpr const char* op_name{"tril"};
  static constexpr bool upper{false};
};
template <typename Triangle>
void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
if (self.numel() == 0) {
return;
}
bool inplace_op = self.is_same(result);
bool inplace_update = false;
Tensor self_c;
std::tie(inplace_update, self_c) = checkTrilTriuBatchContiguous(self, inplace_op);
Tensor result_c;
if (inplace_op && !inplace_update) {
result_c = at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else {
result_c = result;
}
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
ScalarType::BFloat16,
ScalarType::Half,
ScalarType::Bool,
self.scalar_type(),
Triangle::op_name,
[&]{
apply_triu_tril<scalar_t>(result_c, self_c, inplace_op && inplace_update, k, Triangle::upper);
});
if (inplace_op && !inplace_update) {
result.copy_(result_c);
}
}
} // namespace
// CPU entry points of the structured tril/triu kernels. Both delegate to
// compute_triu_tril and differ only in the triangle tag.
TORCH_IMPL_FUNC(tril_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
  // Keeps elements on or below diagonal k.
  using Triangle = LowerTriangle;
  compute_triu_tril<Triangle>(self, k, result);
}
TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
  // Keeps elements on or above diagonal k.
  using Triangle = UpperTriangle;
  compute_triu_tril<Triangle>(self, k, result);
}
// Gradient of trace() for a 2-D input of shape `sizes`. Returns a tensor of
// that shape holding `grad` on the main diagonal and zero elsewhere.
//
// The result is built on a flat buffer. Diagonal element (i, i) of a row-major
// r x c matrix sits at flat index i * (c + 1), but only the first min(r, c)
// such indices are diagonal. Past that point, i * (c + 1) wraps into later
// rows: for a 3 x 1 matrix, index 2 is element (2, 0). The index range is
// therefore capped at min(r, c) * (c + 1) (arange's end is exclusive).
//
// Throws std::runtime_error when `sizes` does not describe a matrix.
Tensor trace_backward(const Tensor& grad, IntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }
  const int64_t rows = sizes[0];
  const int64_t cols = sizes[1];
  const int64_t diag_step = cols + 1;
  const int64_t diag_len = std::min(rows, cols);
  auto grad_input = at::zeros(rows * cols, grad.options());
  auto indices = at::arange(0, diag_len * diag_step, diag_step, grad.options().dtype(at::kLong));
  grad_input.index_fill_(0, indices, grad);
  return grad_input.view(sizes);
}
} // namespace native
} // namespace at