Skip to content

Commit e896b63

Browse files
Use diagonal ResultType for bitwise shifts
1 parent 7848e19 commit e896b63

File tree

3 files changed

+102
-38
lines changed

3 files changed

+102
-38
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp

Lines changed: 52 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -52,30 +52,37 @@ namespace tu_ns = dpctl::tensor::type_utils;
5252
template <typename argT1, typename argT2, typename resT>
5353
struct BitwiseLeftShiftFunctor
5454
{
55-
static_assert(std::is_same_v<resT, argT1>);
5655
static_assert(std::is_integral_v<argT1>);
5756
static_assert(std::is_integral_v<argT2>);
57+
static_assert(!std::is_same_v<argT1, bool>);
58+
static_assert(!std::is_same_v<argT2, bool>);
5859

5960
using supports_sg_loadstore = typename std::true_type;
6061
using supports_vec = typename std::true_type;
6162

6263
resT operator()(const argT1 &in1, const argT2 &in2)
6364
{
64-
return (in1 << in2);
65+
if constexpr (std::is_unsigned_v<argT2>) {
66+
return (in1 << in2);
67+
}
68+
else {
69+
return (in2 < argT2(0)) ? resT(0) : (in1 << in2);
70+
}
6571
}
6672

6773
template <int vec_sz>
6874
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
6975
const sycl::vec<argT2, vec_sz> &in2)
7076
{
71-
if constexpr (std::is_same_v<argT1, argT2>) {
77+
if constexpr (std::is_same_v<argT1, argT2> && std::is_unsigned_v<argT2>)
78+
{
7279
return (in1 << in2);
7380
}
7481
else {
75-
sycl::vec<argT1, vec_sz> res;
82+
sycl::vec<resT, vec_sz> res;
7683
#pragma unroll
7784
for (int i = 0; i < vec_sz; ++i) {
78-
res[i] = (in1[i] << in2[i]);
85+
res[i] = (in2[i] < argT2(0)) ? resT(0) : (in1[i] << in2[i]);
7986
}
8087
return res;
8188
}
@@ -103,19 +110,52 @@ using BitwiseLeftShiftStridedFunctor = elementwise_common::BinaryStridedFunctor<
103110
IndexerT,
104111
BitwiseLeftShiftFunctor<argT1, argT2, resT>>;
105112

106-
using elementwise_common::is_integral_not_boolean;
107-
108113
template <typename T1, typename T2> struct BitwiseLeftShiftOutputType
109114
{
110115
using ResT = T1;
111116
using value_type = typename std::disjunction< // disjunction is C++17
112117
// feature, supported by
113118
// DPC++
114-
td_ns::BinaryConditionalTypeMapResultEntry<T1,
115-
is_integral_not_boolean,
116-
T2,
117-
is_integral_not_boolean,
118-
ResT>,
119+
td_ns::BinaryTypeMapResultEntry<T1,
120+
std::int8_t,
121+
T2,
122+
std::int8_t,
123+
std::int8_t>,
124+
td_ns::BinaryTypeMapResultEntry<T1,
125+
std::uint8_t,
126+
T2,
127+
std::uint8_t,
128+
std::uint8_t>,
129+
td_ns::BinaryTypeMapResultEntry<T1,
130+
std::int16_t,
131+
T2,
132+
std::int16_t,
133+
std::int16_t>,
134+
td_ns::BinaryTypeMapResultEntry<T1,
135+
std::uint16_t,
136+
T2,
137+
std::uint16_t,
138+
std::uint16_t>,
139+
td_ns::BinaryTypeMapResultEntry<T1,
140+
std::int32_t,
141+
T2,
142+
std::int32_t,
143+
std::int32_t>,
144+
td_ns::BinaryTypeMapResultEntry<T1,
145+
std::uint32_t,
146+
T2,
147+
std::uint32_t,
148+
std::uint32_t>,
149+
td_ns::BinaryTypeMapResultEntry<T1,
150+
std::int64_t,
151+
T2,
152+
std::int64_t,
153+
std::int64_t>,
154+
td_ns::BinaryTypeMapResultEntry<T1,
155+
std::uint64_t,
156+
T2,
157+
std::uint64_t,
158+
std::uint64_t>,
119159
td_ns::DefaultResultEntry<void>>::result_type;
120160
};
121161

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp

Lines changed: 50 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -61,21 +61,27 @@ struct BitwiseRightShiftFunctor
6161

6262
resT operator()(const argT1 &in1, const argT2 &in2)
6363
{
64-
return (in1 >> in2);
64+
if constexpr (std::is_unsigned_v<argT2>) {
65+
return (in1 >> in2);
66+
}
67+
else {
68+
return (in2 < argT2(0)) ? resT(0) : (in1 >> in2);
69+
}
6570
}
6671

6772
template <int vec_sz>
6873
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
6974
const sycl::vec<argT2, vec_sz> &in2)
7075
{
71-
if constexpr (std::is_same_v<argT1, argT2>) {
76+
if constexpr (std::is_same_v<argT1, argT2> && std::is_unsigned_v<argT2>)
77+
{
7278
return (in1 >> in2);
7379
}
7480
else {
75-
sycl::vec<argT1, vec_sz> res;
81+
sycl::vec<resT, vec_sz> res;
7682
#pragma unroll
7783
for (int i = 0; i < vec_sz; ++i) {
78-
res[i] = (in1[i] >> in2[i]);
84+
res[i] = (in2[i] < argT2(0)) ? resT(0) : (in1[i] >> in2[i]);
7985
}
8086
return res;
8187
}
@@ -104,19 +110,52 @@ using BitwiseRightShiftStridedFunctor =
104110
IndexerT,
105111
BitwiseRightShiftFunctor<argT1, argT2, resT>>;
106112

107-
using elementwise_common::is_integral_not_boolean;
108-
109113
template <typename T1, typename T2> struct BitwiseRightShiftOutputType
110114
{
111115
using ResT = T1;
112116
using value_type = typename std::disjunction< // disjunction is C++17
113117
// feature, supported by
114118
// DPC++
115-
td_ns::BinaryConditionalTypeMapResultEntry<T1,
116-
is_integral_not_boolean,
117-
T2,
118-
is_integral_not_boolean,
119-
ResT>,
119+
td_ns::BinaryTypeMapResultEntry<T1,
120+
std::int8_t,
121+
T2,
122+
std::int8_t,
123+
std::int8_t>,
124+
td_ns::BinaryTypeMapResultEntry<T1,
125+
std::uint8_t,
126+
T2,
127+
std::uint8_t,
128+
std::uint8_t>,
129+
td_ns::BinaryTypeMapResultEntry<T1,
130+
std::int16_t,
131+
T2,
132+
std::int16_t,
133+
std::int16_t>,
134+
td_ns::BinaryTypeMapResultEntry<T1,
135+
std::uint16_t,
136+
T2,
137+
std::uint16_t,
138+
std::uint16_t>,
139+
td_ns::BinaryTypeMapResultEntry<T1,
140+
std::int32_t,
141+
T2,
142+
std::int32_t,
143+
std::int32_t>,
144+
td_ns::BinaryTypeMapResultEntry<T1,
145+
std::uint32_t,
146+
T2,
147+
std::uint32_t,
148+
std::uint32_t>,
149+
td_ns::BinaryTypeMapResultEntry<T1,
150+
std::int64_t,
151+
T2,
152+
std::int64_t,
153+
std::int64_t>,
154+
td_ns::BinaryTypeMapResultEntry<T1,
155+
std::uint64_t,
156+
T2,
157+
std::uint64_t,
158+
std::uint64_t>,
120159
td_ns::DefaultResultEntry<void>>::result_type;
121160
};
122161

dpctl/tensor/libtensor/include/utils/type_dispatch.hpp

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -271,21 +271,6 @@ struct BinaryTypeMapResultEntry
271271
using result_type = ResTy;
272272
};
273273

274-
/*! @brief struct to define result_type typename for Ty1 == ArgTy1 && Ty2 ==
275-
* ArgTy2 */
276-
template <typename Ty1,
277-
template <typename T>
278-
class Cond1,
279-
typename Ty2,
280-
template <typename T>
281-
class Cond2,
282-
typename ResTy>
283-
struct BinaryConditionalTypeMapResultEntry
284-
: std::bool_constant<std::conjunction_v<Cond1<Ty1>, Cond2<Ty2>>>
285-
{
286-
using result_type = ResTy;
287-
};
288-
289274
/*! @brief fall-through struct with specified result_type, usually void */
290275
template <typename Ty = void> struct DefaultResultEntry : std::true_type
291276
{

0 commit comments

Comments
 (0)