Skip to content

[SYCL] improve operator forwarding in annotated_ref #12140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,8 @@ T operator=(const annotated_ref& other) const;
a|
Equivalent to:
```c++
T tmp = other; // Reads from memory
// with annotations
T tmp = other.operator T(); // Reads from memory
// with annotations
*this = tmp; // Writes to memory
// with annotations
return T;
Expand All @@ -642,8 +642,8 @@ Return result by value.
Available only if the corresponding assignment operator OP is available for `T` taking a type of `O`.
Equivalent to:
```c++
T tmp = *this; // Reads from memory
// with annotations
T tmp = this->operator T(); // Reads from memory
// with annotations
tmp OP std::forward<O>(a);
*this = tmp; // Writes to memory
// with annotations
Expand All @@ -665,8 +665,8 @@ Return result by value.
Available only if the corresponding assignment operator OP is available for `T`.
Equivalent to:
```c++
T tmp = *this; // Reads from memory
// with annotations
T tmp = this->operator T(); // Reads from memory
// with annotations
T tmp2 = b; // Reads from memory
// with annotations
tmp OP b;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,23 @@ class annotated_ref<T, detail::properties_t<Props...>> {

template <class O, class P>
T operator=(const annotated_ref<O, P> &Ref) const {
O t2 = Ref;
O t2 = Ref.operator O();
return *this = t2;
}

// propagate compound operators
#define PROPAGATE_OP(op) \
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>> \
T operator op(O &&rhs) const { \
T t = *this; \
T t = this->operator T(); \
t op std::forward<O>(rhs); \
*this = t; \
return t; \
} \
template <class O, class P> \
T operator op(const annotated_ref<O, P> &rhs) const { \
T t = *this; \
O t2 = rhs; \
T t = this->operator T(); \
O t2 = rhs.operator T(); \
t op t2; \
*this = t; \
return t; \
Expand All @@ -158,12 +158,12 @@ class annotated_ref<T, detail::properties_t<Props...>> {
template <class O> \
friend auto operator op(O &&a, const annotated_ref &b) \
->decltype(std::forward<O>(a) op std::declval<T>()) { \
return std::forward<O>(a) op T(b); \
return std::forward<O>(a) op b.operator T(); \
} \
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>> \
friend auto operator op(const annotated_ref &a, O &&b) \
->decltype(std::declval<T>() op std::forward<O>(b)) { \
return T(a) op std::forward<O>(b); \
return a.operator T() op std::forward<O>(b); \
}
PROPAGATE_OP(+)
PROPAGATE_OP(-)
Expand All @@ -190,7 +190,7 @@ class annotated_ref<T, detail::properties_t<Props...>> {
#define PROPAGATE_OP(op) \
template <typename O = T> \
auto operator op() const->decltype(op std::declval<O>()) { \
return op O(*this); \
return op this->operator O(); \
}
PROPAGATE_OP(+)
PROPAGATE_OP(-)
Expand All @@ -200,29 +200,29 @@ class annotated_ref<T, detail::properties_t<Props...>> {

// Propagate inc/dec operators
T operator++() const {
T t = *this;
T t = this->operator T();
++t;
*this = t;
return t;
}

T operator++(int) const {
T t1 = *this;
T t1 = this->operator T();
T t2 = t1;
t2++;
*this = t2;
return t1;
}

T operator--() const {
T t = *this;
T t = this->operator T();
--t;
*this = t;
return t;
}

T operator--(int) const {
T t1 = *this;
T t1 = this->operator T();
T t2 = t1;
t2--;
*this = t2;
Expand Down
16 changes: 16 additions & 0 deletions sycl/test-e2e/Annotated_arg_ptr/annotated_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ template <typename T> struct MyThirdStruct {
float operator==(const T &rhs) const { return data == rhs.data ? 3.0 : 1.0; }
};

template <typename T> struct MyFourthStruct {
T p;

template <typename T2> MyFourthStruct(const T2 &p_) : p(p_) {}

template <typename T2> void operator=(const T2 &p_) {}

int operator+(const int &rhs) const { return 0; }
int operator+=(const int &rhs) const { return 0; }
};

MySecondStruct::operator MyStruct<int>() const { return MyStruct<int>(0); }

#define BINARY_OP(op) \
Expand Down Expand Up @@ -181,6 +192,11 @@ int main() {

auto *r10 = malloc_shared<int>(1, Q);

auto *r11 = malloc_shared<MyFourthStruct<int>>(1, Q);
annotated_ptr r11_ptr{r11};
auto r11_add = *r11_ptr + 1;
*r11_ptr += 1;

// testing return type of operators
int o1 = 0;
float o2 = 1.5;
Expand Down