Commit 25b6444

swolchok authored and facebook-github-bot committed
support Half in minimum and clamp
Summary: IIUC, these ops should support Half but currently don't; I noticed the gap as a difference from the maximum op.

Differential Revision: D56846242
1 parent a87adda commit 25b6444

File tree

5 files changed: +89 −27 lines
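For background, the `ET_SWITCH_*` macros appearing throughout this diff dispatch a runtime `ScalarType` to a compile-time C++ type by instantiating the given lambda once per supported dtype, so supporting Half mostly means switching to the macro variants whose type lists include Half (the `H` in `REALH`, `REALHB`, and `FLOATH`). Below is a deliberately simplified sketch of that dispatch pattern; every name in it (`MyScalarType`, `MY_SWITCH_FLOATH_TYPES`, `HalfStandIn`) is invented for illustration, and the real macros additionally take a kernel context and an op name and cover many more dtypes.

```cpp
#include <cstdio>

// Sketch only: MyScalarType, HalfStandIn, and MY_SWITCH_FLOATH_TYPES are
// invented names, not ExecuTorch APIs.
enum class MyScalarType { Float, Double, Half };

struct HalfStandIn {
  unsigned short bits; // 16-bit payload, standing in for torch::executor::Half
};

#define MY_SWITCH_FLOATH_TYPES(TYPE, CTYPE_ALIAS, ...)          \
  switch (TYPE) {                                               \
    case MyScalarType::Float: {                                 \
      using CTYPE_ALIAS = float;                                \
      __VA_ARGS__();                                            \
      break;                                                    \
    }                                                           \
    case MyScalarType::Double: {                                \
      using CTYPE_ALIAS = double;                               \
      __VA_ARGS__();                                            \
      break;                                                    \
    }                                                           \
    case MyScalarType::Half: { /* the case Half support adds */ \
      using CTYPE_ALIAS = HalfStandIn;                          \
      __VA_ARGS__();                                            \
      break;                                                    \
    }                                                           \
  }

int main() {
  MyScalarType t = MyScalarType::Half;
  MY_SWITCH_FLOATH_TYPES(t, CTYPE, [&]() {
    // CTYPE is bound to a different concrete type in each case branch.
    std::printf("dispatched, sizeof(CTYPE) = %zu\n", sizeof(CTYPE));
  });
}
```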

kernels/portable/cpu/op_clamp.cpp

Lines changed: 9 additions & 9 deletions
```diff
@@ -53,7 +53,7 @@ bool is_out_of_bounds(CTYPE_VAL val) {
       }
     });
   } else if (isFloatingType(out_type)) {
-    ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
+    ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
       if (std::isfinite(val) &&
           is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
         ET_LOG(Error, "%s value out of bounds", val_name);
@@ -119,7 +119,7 @@ Tensor& clamp_out(

   ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
+  ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
     // Extract optional min value
     CTYPE_OUT min = 0;
     if (has_min) {
@@ -140,7 +140,7 @@ Tensor& clamp_out(
       });
     }

-    ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
+    ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
       apply_unary_map_fn(
           [has_min, min, has_max, max](const CTYPE_IN val_in) {
             CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
@@ -195,20 +195,20 @@ Tensor& clamp_tensor_out(
   ScalarType out_type = out.scalar_type();

   if (has_min) {
-    common_type = promoteTypes(common_type, min_type);
+    common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
   }
   if (has_max) {
-    common_type = promoteTypes(common_type, max_type);
+    common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
   }

   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

   constexpr auto name = "clamp.Tensor_out";

-  ET_SWITCH_REALB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
-    ET_SWITCH_REALB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
-      ET_SWITCH_REALB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
-        ET_SWITCH_REALB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
+    ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
+      ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
+        ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
           apply_ternary_elementwise_fn<
               CTYPE_IN,
               CTYPE_MIN,
```
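One subtlety the first hunk touches: `clamp` validates the `min`/`max` Scalars against the output dtype's representable range before casting, and with `ET_SWITCH_FLOATH_TYPES` that check now also runs with `CTYPE_OUT` being Half, whose largest finite value is only 65504. Here is a simplified stand-in for that bounds check (the real `is_out_of_bounds` in op_clamp.cpp may differ in detail, and `float` substitutes for Half below because pre-C++23 C++ has no standard half-precision type):

```cpp
#include <iostream>
#include <limits>

// Simplified stand-in for is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, CTYPE_ACC>:
// widen the scalar to an accumulation type so the comparison itself cannot
// overflow, then test it against the output dtype's representable range.
template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_ACC>
bool is_out_of_bounds(CTYPE_VAL val) {
  const CTYPE_ACC val_acc = static_cast<CTYPE_ACC>(val);
  return val_acc <
      static_cast<CTYPE_ACC>(std::numeric_limits<CTYPE_OUT>::lowest()) ||
      val_acc > static_cast<CTYPE_ACC>(std::numeric_limits<CTYPE_OUT>::max());
}

int main() {
  // With a Half output, the analogous check rejects bounds beyond +/-65504.
  std::cout << is_out_of_bounds<double, float, double>(1e39) << '\n'; // 1
  std::cout << is_out_of_bounds<double, float, double>(1e5) << '\n';  // 0
}
```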

kernels/portable/cpu/op_minimum.cpp

Lines changed: 13 additions & 14 deletions
```diff
@@ -80,25 +80,24 @@ Tensor& minimum_out(

   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
+  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
   ScalarType out_type = out.scalar_type();

   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
-      using CTYPE_IN =
-          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
+  ET_SWITCH_REALHB_TYPES(a_type, ctx, "minimum.out", CTYPE_A, [&]() {
+    ET_SWITCH_REALHB_TYPES(b_type, ctx, "minimum.out", CTYPE_B, [&]() {
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
       ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
-            MinimumInner<
-                can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                CTYPE_A,
-                CTYPE_B,
-                CTYPE_IN,
-                CTYPE_OUT>::run(a, b, out);
-          });
+      ET_SWITCH_REALHB_TYPES(out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
+        MinimumInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, out);
+      });
     });
   });
```
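The functional core of this hunk is the `/*half_to_float*/ true` argument: when either input is Half, both the runtime `promoteTypes` call and the compile-time `promote_types` trait select `float` as the computation type, so the comparison runs at float precision. A toy sketch of that rule under stated assumptions (`Half` and `promote_types_sketch` are invented names, and the real trait handles the full dtype promotion lattice):

```cpp
#include <type_traits>

struct Half {}; // stand-in for torch::executor::Half (illustration only)

template <typename A, typename B, bool half_to_float = false>
struct promote_types_sketch {
  static constexpr bool has_half =
      std::is_same<A, Half>::value || std::is_same<B, Half>::value;
  // Assumes A == B when neither operand is Half; the real trait computes a
  // proper common type across the whole dtype lattice.
  using type =
      typename std::conditional<half_to_float && has_half, float, A>::type;
};

// Half op Half computes in float once the flag is set.
static_assert(
    std::is_same<promote_types_sketch<Half, Half, true>::type, float>::value,
    "");
// Non-Half pairs are unaffected by the flag.
static_assert(
    std::is_same<promote_types_sketch<double, double, true>::type, double>::value,
    "");
```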

kernels/portable/cpu/util/math_util.h

Lines changed: 42 additions & 0 deletions
```diff
@@ -94,6 +94,48 @@ INT_T max_override(INT_T a, INT_T b) {
   return std::max(a, b);
 }

+template <
+    typename T,
+    typename std::enable_if<
+        std::is_same<T, torch::executor::Half>::value,
+        bool>::type = true>
+T min_override(T a, T b) {
+  const auto float_a = static_cast<float>(a);
+  if (std::isnan(float_a)) {
+    return a;
+  }
+  const auto float_b = static_cast<float>(b);
+  if (std::isnan(float_b)) {
+    return b;
+  }
+
+  if (float_a < float_b) {
+    return a;
+  }
+  return b;
+}
+
+template <
+    typename T,
+    typename std::enable_if<
+        std::is_same<T, torch::executor::Half>::value,
+        bool>::type = true>
+T max_override(T a, T b) {
+  const auto float_a = static_cast<float>(a);
+  if (std::isnan(float_a)) {
+    return a;
+  }
+  const auto float_b = static_cast<float>(b);
+  if (std::isnan(float_b)) {
+    return b;
+  }
+
+  if (float_a > float_b) {
+    return a;
+  }
+  return b;
+}
+
 /**
  * There is a slight difference in how std::fmod works compared to how ATen
  * determines remainders:
```
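The new overloads convert Half operands to float before comparing, and they propagate NaN (returning the NaN operand) rather than skipping it the way `std::fmin`/`std::fmax` do, which matches ATen's `minimum`/`maximum` semantics. A self-contained illustration of the same logic, with `float` standing in for Half since standard C++ before C++23 has no half-precision type:

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Mirrors the Half min_override above: compare in float, propagate NaN.
float min_override_like(float a, float b) {
  if (std::isnan(a)) {
    return a;
  }
  if (std::isnan(b)) {
    return b;
  }
  return a < b ? a : b;
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  assert(std::isnan(min_override_like(nan, 1.0f))); // NaN propagates
  assert(min_override_like(2.0f, 1.0f) == 1.0f);
  assert(std::fmin(nan, 1.0f) == 1.0f); // std::fmin drops NaN instead
}
```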

kernels/test/op_clamp_test.cpp

Lines changed: 21 additions & 4 deletions
```diff
@@ -147,8 +147,16 @@ class OpClampOutTest : public OperatorTest {
   // Test cases that are compatible with float and double.
   template <ScalarType DTYPE>
   void run_floating_point_test_cases() {
-    constexpr auto kInfinity =
-        std::numeric_limits<typename TensorFactory<DTYPE>::ctype>::infinity();
+    using ctype = typename TensorFactory<DTYPE>::ctype;
+    using opt_infinity_type = std::conditional_t<
+        std::is_same<ctype, exec_aten::Half>::value,
+        float,
+        ctype>;
+    constexpr auto kInfinity = std::numeric_limits<ctype>::infinity();
+    const auto kOptInfinity =
+        OptScalar(static_cast<opt_infinity_type>(kInfinity));
+    const auto kOptMinusInfinity =
+        OptScalar(static_cast<opt_infinity_type>(-kInfinity));
     std::vector<ClampTestCase<DTYPE>> test_cases = {
         {
             std::string(__func__) + ": Simple negative/positive clamp",
@@ -178,7 +186,7 @@ class OpClampOutTest : public OperatorTest {
             std::string(__func__) + ": Infinite min",
             {2, 2}, // sizes
             {-10.1, -1.1, 1.1, 10.1}, // input_data
-            OptScalar(-kInfinity), // min
+            kOptMinusInfinity, // min
             OptScalar(5.5), // max
             {-10.1, -1.1, 1.1, 5.5}, // expected_data
         },
@@ -187,7 +195,7 @@ class OpClampOutTest : public OperatorTest {
             {2, 2}, // sizes
             {-10.1, -1.1, 1.1, 10.1}, // input_data
             OptScalar(-5.5), // min
-            OptScalar(kInfinity), // max
+            kOptInfinity, // max
             {-5.5, -1.1, 1.1, 10.1}, // expected_data
         },
         {
@@ -285,6 +293,15 @@ TEST_F(OpClampOutTest, LongTensors) {
   run_signed_integer_test_cases<ScalarType::Long>();
 }

+TEST_F(OpClampOutTest, HalfTensors) {
+  // Note that the integer test cases test the situation where the min/max
+  // value Scalars are integer types, demonstrating that floating point types
+  // can be clamped to integer values.
+  run_unsigned_integer_test_cases<ScalarType::Half>();
+  run_signed_integer_test_cases<ScalarType::Half>();
+  run_floating_point_test_cases<ScalarType::Half>();
+}
+
 TEST_F(OpClampOutTest, FloatTensors) {
   // Note that the integer test cases test the situation where the min/max value
   // Scalars are integer types, demonstrating that floating point types can be
```
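The `opt_infinity_type` alias in the first hunk sidesteps constructing the infinity Scalar directly from a Half value: when the tensor ctype is Half, the test widens the infinity to `float` before wrapping it in `OptScalar`, presumably because the Scalar wrapper does not accept a Half payload. A minimal sketch of the same type selection (`Half` below is a stand-in for `exec_aten::Half`):

```cpp
#include <type_traits>

struct Half {}; // stand-in for exec_aten::Half (illustration only)

// When the tensor ctype is Half, carry infinities through the optional
// Scalar as float; otherwise use the ctype itself.
template <typename ctype>
using opt_infinity_type =
    std::conditional_t<std::is_same<ctype, Half>::value, float, ctype>;

static_assert(std::is_same<opt_infinity_type<Half>, float>::value, "");
static_assert(std::is_same<opt_infinity_type<double>, double>::value, "");
```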

kernels/test/op_minimum_test.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -65,6 +65,10 @@ TEST_F(OpMinimumOutTest, LongTensors) {
   test_minimum_out_same_size<ScalarType::Long>();
 }

+TEST_F(OpMinimumOutTest, HalfTensors) {
+  test_minimum_out_same_size<ScalarType::Half>();
+}
+
 TEST_F(OpMinimumOutTest, FloatTensors) {
   test_minimum_out_same_size<ScalarType::Float>();
 }
```
