// Pow.cpp
#include <ATen/native/Pow.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/Resize.h>
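
// This file implements the pow and float_power operators using ATen's
// structured-kernel split: the at::meta functions below compute the output
// dtype/shape, and the TORCH_IMPL_FUNC bodies in at::native perform (or
// redispatch) the actual computation through per-device dispatch stubs.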
namespace at {
namespace meta {

TORCH_META_FUNC2(pow, Tensor_Tensor) (const Tensor& base, const Tensor& exp) {
  build_borrowing_binary_op(maybe_get_output(), base, exp);
}
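
// The Tensor-Scalar overload rejects integer bases raised to negative integer
// exponents (NumPy-compatible behaviour), since the result is not representable
// in an integral dtype. Illustrative sketch (assumed tensor values):
//   at::pow(int64_tensor, -2)   // error: "Integers to negative integer powers are not allowed."
//   at::pow(double_tensor, -2)  // fine: computed in floating point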
TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
  // Numpy compatibility check:
  TORCH_CHECK(!(isIntegralType(base.scalar_type(), true) &&
                exp.isIntegral(true) && exp.toLong() < 0),
              "Integers to negative integer powers are not allowed.");

  auto common_dtype = at::result_type(base, exp);
  build_output_borrowing_argument_owning_unary_op(maybe_get_output(), base.to(common_dtype));
}

TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
  // This overload doesn't directly use TensorIterator. The corresponding impl
  // short-circuits when the base is 1 and otherwise redispatches to the
  // Tensor_Tensor overload.
  auto dtype = maybe_get_output().defined() ? maybe_get_output().scalar_type() : at::result_type(base, exp);
  set_output(0, exp.sizes(), {}, exp.options().dtype(dtype), exp.has_names() ? exp.names() : ArrayRef<Dimname>());
}

} // namespace meta

namespace native {
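
// DEFINE_DISPATCH declares the per-device dispatch tables; the concrete CPU/CUDA
// kernels are registered against these stubs in the device-specific kernel files.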
DEFINE_DISPATCH(pow_tensor_tensor_stub);
DEFINE_DISPATCH(pow_tensor_scalar_stub);

TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, const Tensor& out) {
  pow_tensor_tensor_stub(device_type(), *this);
}
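
// Tensor-Scalar pow: exponents 0 and 1 are handled without launching a kernel
// (fill with ones / copy the base); other exponents go through the dispatch stub.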
TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
  if (exp.equal(0.0)) {
    out.fill_(1);
  } else if (exp.equal(1.0)) {
    out.copy_(base);
  } else {
    pow_tensor_scalar_stub(device_type(), *this, exp);
  }
}
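
// Scalar-Tensor pow: a base of 1 trivially yields ones; otherwise the scalar base
// is wrapped into a tensor on the exponent's device and the call redispatches to
// the Tensor-Tensor overload above.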
TORCH_IMPL_FUNC(pow_Scalar_out) (const Scalar& base, const Tensor& exp, const Tensor& out) {
  if (base.equal(1.0)) {
    out.fill_(1);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    at::pow_out(const_cast<Tensor&>(out), wrapped_scalar_tensor(base, exp.device()), exp); // redispatch!
  }
}
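
// float_power always computes in double precision: kDouble, or kComplexDouble if
// either operand is complex, regardless of the input dtypes. The out= variants
// additionally require the destination to already have that dtype.
//
// Illustrative sketch (example values, not part of this file):
//   at::Tensor base = at::arange(1, 4);       // int64 tensor [1, 2, 3]
//   at::Tensor a = at::pow(base, 2);          // stays integral: [1, 4, 9]
//   at::Tensor b = at::float_power(base, 2);  // double result: [1., 4., 9.]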
Tensor& float_power_out(const Tensor& base, const Tensor& exp, Tensor& result) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
                at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  return at::pow_out(result, base.to(dtype), exp.to(dtype));
}

Tensor& float_power_out(const Tensor& base, const Scalar& exp, Tensor& result) {
  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  // Note: need the casts inside the ternary because conversion functions return e.g. c10::complex,
  // which causes a complex scalar to always be returned.
  auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
  return at::pow_out(result, base.to(dtype), casted_exp);
}

Tensor& float_power_out(const Scalar& base, const Tensor& exp, Tensor& result) {
  auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  auto casted_base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
  return at::pow_out(result, casted_base, exp.to(dtype));
}
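
// The functional (non-out) float_power overloads below promote their inputs to
// the computation dtype and delegate to at::pow, which allocates the result.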
Tensor float_power(const Tensor& base, const Scalar& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
  auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
  return at::pow(base.to(dtype), casted_exp);
}

Tensor float_power(const Scalar& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
  auto casted_base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
  return at::pow(casted_base, exp.to(dtype));
}

Tensor float_power(const Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  return at::pow(base.to(dtype), exp.to(dtype));
}
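
// In-place variants: an in-place op cannot change the dtype of `base`, so these
// require that `base` already carries the double (or complex double) computation dtype.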
Tensor& float_power_(Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(base.scalar_type() == dtype,
              "the base given to float_power_ has dtype ", base.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  return base.pow_(exp.to(dtype));
}

Tensor& float_power_(Tensor& base, const Scalar& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(base.scalar_type() == dtype,
              "the base given to float_power_ has dtype ", base.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
  return base.pow_(casted_exp);
}

} // namespace native
} // namespace at