Skip to content

[SYCLCompat] Fix vectorized_binary impl to make SYCLomatic migrated code run pass #16553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions sycl/doc/syclcompat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2090,6 +2090,25 @@ struct sub_sat {
} // namespace syclcompat
```

`vectorized_binary` also supports comparison operators from the standard library (`std::equal_to`, `std::not_equal_to`, etc)
and the semantics can be modified by changing the comparison operator template instantiation. For example:

```cpp
unsigned int Input1;
unsigned int Input2;
// initialize inputs...

// Performs comparison on sycl::ushort2, following sycl::vec semantics
// Returns unsigned int containing, per vector element, 0xFFFF if true, and 0x0000 if false
syclcompat::vectorized_binary<sycl::ushort2>(
Input1, Input2, std::equal_to<>());

// Performs element-wise comparison on unsigned short
// Returns unsigned int containing, per vector element, 1 if true, and 0 if false
syclcompat::vectorized_binary<sycl::ushort2>(
Input1, Input2, std::equal_to<unsigned short>());
```

The math header provides a set of functions to extend 32-bit operations
to 33 bit, and handle sign extension internally. There is support for `add`,
`sub`, `absdiff`, `min` and `max` operations. Each operation provides overloads
Expand Down
39 changes: 7 additions & 32 deletions sycl/include/syclcompat/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,34 +119,13 @@ class vectorized_binary {
}
};

// Vectorized_binary for logical operations
template <typename VecT, class BinaryOperation>
class vectorized_binary<
VecT, BinaryOperation,
std::enable_if_t<std::is_same_v<
bool, decltype(std::declval<BinaryOperation>()(
std::declval<typename VecT::element_type>(),
std::declval<typename VecT::element_type>()))>>> {
std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>> {
public:
inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) {
unsigned result = 0;
constexpr size_t elem_size = 8 * sizeof(typename VecT::element_type);
static_assert(elem_size < 32,
"Vector element size must be less than 4 bytes");
constexpr unsigned bool_mask = (1U << elem_size) - 1;

for (size_t i = 0; i < a.size(); ++i) {
bool comp_result = binary_op(a[i], b[i]);
result |= (comp_result ? bool_mask : 0U) << (i * elem_size);
}

VecT v4;
for (size_t i = 0; i < v4.size(); ++i) {
v4[i] = static_cast<typename VecT::element_type>(
(result >> (i * elem_size)) & bool_mask);
}

return v4;
return binary_op(a, b).template as<VecT>();
}
};

Expand Down Expand Up @@ -694,8 +673,9 @@ inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) {
template <typename VecT>
inline unsigned vectorized_sum_abs_diff(unsigned a, unsigned b) {
sycl::vec<unsigned, 1> v0{a}, v1{b};
auto v2 = v0.as<VecT>();
auto v3 = v1.as<VecT>();
// Need convert element type to wider signed type to avoid overflow.
auto v2 = v0.as<VecT>().template convert<int>();
auto v3 = v1.as<VecT>().template convert<int>();
auto v4 = sycl::abs_diff(v2, v3);
unsigned sum = 0;
for (size_t i = 0; i < v4.size(); ++i) {
Expand Down Expand Up @@ -1095,13 +1075,8 @@ inline unsigned vectorized_binary(unsigned a, unsigned b,
auto v3 = v1.as<VecT>();
auto v4 =
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
if constexpr (!std::is_same_v<
bool, decltype(std::declval<BinaryOperation>()(
std::declval<typename VecT::element_type>(),
std::declval<typename VecT::element_type>()))>) {
if (need_relu)
v4 = relu(v4);
}
if (need_relu)
v4 = relu(v4);
v0 = v4.template as<sycl::vec<unsigned, 1>>();
return v0;
}
Expand Down
Loading
Loading