Skip to content

Commit 68e13d4

Browse files
authored
[SYCL] Add vec assignment from scalar and more vec modulus overloads (#9031)
`vec` now supports: - `swizzle = scalar` - `swizzle % swizzle` - `swizzle % scalar` - `swizzle % vec` - `scalar % swizzle` - `scalar % vec` Fixes #8881 and #8877 --------- Signed-off-by: Cai, Justin <justin.cai@intel.com>
1 parent 67da385 commit 68e13d4

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

sycl/include/sycl/types.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,16 @@ class SwizzleOp {
16181618
return *this;
16191619
}
16201620
1621+
template <int IdxNum = getNumElements(),
1622+
EnableIfMultipleIndexes<IdxNum, bool> = true>
1623+
SwizzleOp &operator=(const DataT &Rhs) {
1624+
std::array<int, IdxNum> Idxs{Indexes...};
1625+
for (auto Idx : Idxs) {
1626+
m_Vector->setValue(Idx, Rhs);
1627+
}
1628+
return *this;
1629+
}
1630+
16211631
template <int IdxNum = getNumElements(), typename = EnableIfOneIndex<IdxNum>>
16221632
SwizzleOp &operator=(DataT &&Rhs) {
16231633
std::array<int, IdxNum> Idxs{Indexes...};
@@ -1682,6 +1692,21 @@ class SwizzleOp {
16821692
Rhs);
16831693
}
16841694
1695+
template <typename T, typename = EnableIfScalarType<T>>
1696+
NewLHOp<GetScalarOp<T>, std::modulus, Indexes...>
1697+
operator%(const T &Rhs) const {
1698+
return NewLHOp<GetScalarOp<T>, std::modulus, Indexes...>(
1699+
m_Vector, *this, GetScalarOp<T>(Rhs));
1700+
}
1701+
1702+
template <typename RhsOperation,
1703+
typename = EnableIfNoScalarType<RhsOperation>>
1704+
NewLHOp<RhsOperation, std::modulus, Indexes...>
1705+
operator%(const RhsOperation &Rhs) const {
1706+
return NewLHOp<RhsOperation, std::modulus, Indexes...>(m_Vector, *this,
1707+
Rhs);
1708+
}
1709+
16851710
template <typename T, typename = EnableIfScalarType<T>>
16861711
NewLHOp<GetScalarOp<T>, std::bit_and, Indexes...>
16871712
operator&(const T &Rhs) const {
@@ -2054,6 +2079,7 @@ __SYCL_BINOP(+)
20542079
__SYCL_BINOP(-)
20552080
__SYCL_BINOP(*)
20562081
__SYCL_BINOP(/)
2082+
__SYCL_BINOP(%)
20572083
__SYCL_BINOP(&)
20582084
__SYCL_BINOP(|)
20592085
__SYCL_BINOP(^)

sycl/test/basic_tests/vectors/vectors.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,33 @@ template <typename From> void check_convert_from() {
5555
check_signed_unsigned_convert_to<From, double>();
5656
}
5757

58+
template <typename T, typename OpT> void check_ops(OpT op, T c1, T c2) {
59+
auto check = [&](sycl::vec<T, 2> vres) {
60+
assert(op(c1, c2) == vres[0]);
61+
assert(op(c1, c2) == vres[1]);
62+
};
63+
64+
sycl::vec<T, 2> v1(c1);
65+
sycl::vec<T, 2> v2(c2);
66+
check(op(v1.template swizzle<0, 1>(), v2.template swizzle<0, 1>()));
67+
check(op(v1.template swizzle<0, 1>(), v2));
68+
check(op(v1.template swizzle<0, 1>(), c2));
69+
check(op(c1, v2.template swizzle<0, 1>()));
70+
check(op(c1, v2));
71+
check(op(v1, v2.template swizzle<0, 1>()));
72+
check(op(v1, v2));
73+
check(op(v1, c2));
74+
75+
sycl::vec<T, 2> v3 = {c1, c2};
76+
sycl::vec<T, 2> v4 = op(v3, v3.template swizzle<1, 0>());
77+
assert(v4[0] == op(c1, c2) && v4[1] == op(c2, c1));
78+
sycl::vec<T, 2> v5 = op(v3.template swizzle<1, 1>(), v3);
79+
assert(v5[0] == op(c2, c1) && v5[1] == op(c2, c2));
80+
sycl::vec<T, 2> v6 =
81+
op(v3.template swizzle<1, 1>(), v3.template swizzle<0, 0>());
82+
assert(v6[0] == op(c2, c1) && v6[1] == op(c2, c1));
83+
}
84+
5885
int main() {
5986
sycl::int4 a = {1, 2, 3, 4};
6087
const sycl::int4 b = {10, 20, 30, 40};
@@ -91,6 +118,11 @@ int main() {
91118
assert(static_cast<float>(b_vec.y()) == static_cast<float>(0.5));
92119
assert(static_cast<float>(b_vec.z()) == static_cast<float>(0.5));
93120
assert(static_cast<float>(b_vec.w()) == static_cast<float>(0.5));
121+
b_vec.swizzle<0, 1, 2, 3>() = 0.6;
122+
assert(static_cast<float>(b_vec.x()) == static_cast<float>(0.6));
123+
assert(static_cast<float>(b_vec.y()) == static_cast<float>(0.6));
124+
assert(static_cast<float>(b_vec.z()) == static_cast<float>(0.6));
125+
assert(static_cast<float>(b_vec.w()) == static_cast<float>(0.6));
94126

95127
// Check that vector with 'unsigned long long' elements has enough bits to
96128
// store value.
@@ -142,5 +174,7 @@ int main() {
142174
check_convert_from<double>();
143175
check_convert_from<bool>();
144176

177+
check_ops<int>(std::modulus(), 6, 3);
178+
145179
return 0;
146180
}

0 commit comments

Comments
 (0)