Skip to content

Commit dba1c20

Browse files
rolandschulzbader
authored andcommitted
[SYCL] Add buffer container constructor and CTAD (#773)
Prototype of a proposal to add a buffer constructor which takes a contiguous buffer as an argument to simplify usage. Signed-off-by: Roland Schulz <roland.schulz@intel.com>
1 parent d6aa11b commit dba1c20

File tree

4 files changed

+121
-19
lines changed

4 files changed

+121
-19
lines changed

sycl/include/CL/sycl/buffer.hpp

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,18 @@ class buffer {
3030
using allocator_type = AllocatorT;
3131
template <int dims>
3232
using EnableIfOneDimension = typename std::enable_if<1 == dims>::type;
33-
33+
// using same requirement for contiguous container as std::span
34+
template <class Container>
35+
using EnableIfContiguous =
36+
detail::void_t<detail::enable_if_t<std::is_convertible<
37+
detail::remove_pointer_t<decltype(
38+
std::declval<Container>().data())> (*)[],
39+
const T (*)[]>::value>,
40+
decltype(std::declval<Container>().size())>;
41+
template <class It>
42+
using EnableIfItInputIterator = detail::enable_if_t<
43+
std::is_convertible<typename std::iterator_traits<It>::iterator_category,
44+
std::input_iterator_tag>::value>;
3445
template <typename ItA, typename ItB>
3546
using EnableIfSameNonConstIterators =
3647
typename std::enable_if<std::is_same<ItA, ItB>::value &&
@@ -107,7 +118,8 @@ class buffer {
107118
}
108119

109120
template <class InputIterator, int N = dimensions,
110-
typename = EnableIfOneDimension<N>>
121+
typename = EnableIfOneDimension<N>,
122+
typename = EnableIfItInputIterator<InputIterator>>
111123
buffer(InputIterator first, InputIterator last, AllocatorT allocator,
112124
const property_list &propList = {})
113125
: Range(range<1>(std::distance(first, last))) {
@@ -117,7 +129,8 @@ class buffer {
117129
}
118130

119131
template <class InputIterator, int N = dimensions,
120-
typename = EnableIfOneDimension<N>>
132+
typename = EnableIfOneDimension<N>,
133+
typename = EnableIfItInputIterator<InputIterator>>
121134
buffer(InputIterator first, InputIterator last,
122135
const property_list &propList = {})
123136
: Range(range<1>(std::distance(first, last))) {
@@ -126,6 +139,26 @@ class buffer {
126139
detail::getNextPowerOfTwo(sizeof(T)), propList);
127140
}
128141

142+
// This constructor is a prototype for a future SYCL specification
143+
template <class Container, int N = dimensions,
144+
typename = EnableIfOneDimension<N>,
145+
typename = EnableIfContiguous<Container>>
146+
buffer(Container &container, AllocatorT allocator,
147+
const property_list &propList = {})
148+
: Range(range<1>(container.size())) {
149+
impl = std::make_shared<detail::buffer_impl<AllocatorT>>(
150+
container.data(), container.data() + container.size(),
151+
get_count() * sizeof(T), detail::getNextPowerOfTwo(sizeof(T)), propList,
152+
allocator);
153+
}
154+
155+
// This constructor is a prototype for a future SYCL specification
156+
template <class Container, int N = dimensions,
157+
typename = EnableIfOneDimension<N>,
158+
typename = EnableIfContiguous<Container>>
159+
buffer(Container &container, const property_list &propList = {})
160+
: buffer(container, {}, propList) {}
161+
129162
buffer(buffer<T, dimensions, AllocatorT> &b, const id<dimensions> &baseIndex,
130163
const range<dimensions> &subRange)
131164
: impl(b.impl), Range(subRange),
@@ -317,6 +350,30 @@ class buffer {
317350
return newRange[1] == parentRange[1] && newRange[2] == parentRange[2];
318351
}
319352
};
353+
354+
#ifdef __cpp_deduction_guides
355+
template <class InputIterator, class AllocatorT>
356+
buffer(InputIterator, InputIterator, AllocatorT, const property_list & = {})
357+
->buffer<typename std::iterator_traits<InputIterator>::value_type, 1,
358+
AllocatorT>;
359+
template <class InputIterator>
360+
buffer(InputIterator, InputIterator, const property_list & = {})
361+
->buffer<typename std::iterator_traits<InputIterator>::value_type, 1>;
362+
template <class Container, class AllocatorT>
363+
buffer(Container &, AllocatorT, const property_list & = {})
364+
->buffer<typename Container::value_type, 1, AllocatorT>;
365+
template <class Container>
366+
buffer(Container &, const property_list & = {})
367+
->buffer<typename Container::value_type, 1>;
368+
template <class T, int dimensions, class AllocatorT>
369+
buffer(const T *, const range<dimensions> &, AllocatorT,
370+
const property_list & = {})
371+
->buffer<T, dimensions, AllocatorT>;
372+
template <class T, int dimensions>
373+
buffer(const T *, const range<dimensions> &, const property_list & = {})
374+
->buffer<T, dimensions>;
375+
#endif // __cpp_deduction_guides
376+
320377
} // namespace sycl
321378
} // namespace cl
322379

sycl/include/CL/sycl/detail/stl_type_traits.hpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,9 @@ namespace cl {
1616
namespace sycl {
1717
namespace detail {
1818

19-
template <bool V> using bool_constant = std::integral_constant<bool, V>;
20-
21-
template <typename T>
22-
using allocator_value_type_t = typename std::allocator_traits<T>::value_type;
23-
24-
template <typename T>
25-
using allocator_pointer_t = typename std::allocator_traits<T>::pointer;
26-
19+
// Type traits identical to those in std in newer versions. Can be removed when
20+
// SYCL requires a newer version of the C++ standard.
21+
// C++14
2722
template <bool B, class T = void>
2823
using enable_if_t = typename std::enable_if<B, T>::type;
2924

@@ -40,6 +35,18 @@ using remove_reference_t = typename std::remove_reference<T>::type;
4035

4136
template <typename T> using add_pointer_t = typename std::add_pointer<T>::type;
4237

38+
// C++17
39+
template <bool V> using bool_constant = std::integral_constant<bool, V>;
40+
41+
template <class...> using void_t = void;
42+
43+
// Custom type traits
44+
template <typename T>
45+
using allocator_value_type_t = typename std::allocator_traits<T>::value_type;
46+
47+
template <typename T>
48+
using allocator_pointer_t = typename std::allocator_traits<T>::pointer;
49+
4350
template <typename T>
4451
using iterator_category_t = typename std::iterator_traits<T>::iterator_category;
4552

@@ -53,15 +60,13 @@ template <typename T>
5360
using iterator_to_const_type_t =
5461
std::is_const<typename std::remove_pointer<iterator_pointer_t<T>>::type>;
5562

56-
template <class...> using requirements_list = void;
57-
5863
// TODO Align with C++ named requirements: LegacyOutputIterator
5964
// https://en.cppreference.com/w/cpp/named_req/OutputIterator
6065
template <typename T>
6166
using output_iterator_requirements =
62-
requirements_list<iterator_category_t<T>,
63-
decltype(*std::declval<T>() =
64-
std::declval<iterator_value_type_t<T>>())>;
67+
void_t<iterator_category_t<T>,
68+
decltype(*std::declval<T>() =
69+
std::declval<iterator_value_type_t<T>>())>;
6570

6671
template <typename, typename = void> struct is_output_iterator {
6772
static constexpr bool value = false;

sycl/test/basic_tests/buffer/buffer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ int main() {
556556
std::vector<int> data2(10, -2);
557557
{
558558
buffer<int, 1> a(data1.data(), range<1>(10));
559-
buffer<int, 1> b(data2.data(), range<1>(10));
559+
buffer<int, 1> b(data2);
560560

561561
program prog(myQueue.get_context());
562562
prog.build_with_source("kernel void override_source(global int* Acc) "
@@ -581,7 +581,7 @@ int main() {
581581
std::vector<int> data2(10, -2);
582582
{
583583
buffer<int, 1> a(data1.data(), range<1>(10));
584-
buffer<int, 1> b(data2.data(), range<1>(10));
584+
buffer<int, 1> b(data2);
585585
accessor<int, 1, access::mode::read_write, access::target::global_buffer,
586586
access::placeholder::true_t>
587587
A(a);
@@ -609,7 +609,7 @@ int main() {
609609
std::vector<int> data2(10, -2);
610610
{
611611
buffer<int, 1> a(data1.data(), range<1>(10));
612-
buffer<int, 1> b(data2.data(), range<1>(10));
612+
buffer<int, 1> b(data2);
613613
accessor<int, 1, access::mode::read_write,
614614
access::target::global_buffer, access::placeholder::true_t>
615615
A(a);
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %clangxx -std=c++17 -fsyntax-only -Xclang -verify %s
2+
// expected-no-diagnostics
3+
//==------------------- buffer_ctad.cpp - SYCL buffer CTAD test ----------------==//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include <CL/sycl.hpp>
12+
#include <cassert>
13+
#include <memory>
14+
15+
using namespace cl::sycl;
16+
17+
int main() {
18+
std::vector<int> v(5, 1);
19+
const std::vector<int> cv(5, 1);
20+
buffer b1(v.data(), range<1>(5));
21+
static_assert(std::is_same_v<decltype(b1), buffer<int, 1>>);
22+
buffer b1a(v.data(), range<1>(5), std::allocator<int>());
23+
static_assert(
24+
std::is_same_v<decltype(b1a), buffer<int, 1, std::allocator<int>>>);
25+
buffer b1b(cv.data(), range<1>(5));
26+
static_assert(std::is_same_v<decltype(b1b), buffer<int, 1>>);
27+
buffer b1c(v.data(), range<2>(2, 2));
28+
static_assert(std::is_same_v<decltype(b1c), buffer<int, 2>>);
29+
buffer b2(v.begin(), v.end());
30+
static_assert(std::is_same_v<decltype(b2), buffer<int, 1>>);
31+
buffer b2a(v.cbegin(), v.cend());
32+
static_assert(std::is_same_v<decltype(b2a), buffer<int, 1>>);
33+
buffer b3(v);
34+
static_assert(std::is_same_v<decltype(b3), buffer<int, 1>>);
35+
buffer b3a(cv);
36+
static_assert(std::is_same_v<decltype(b3a), buffer<int, 1>>);
37+
shared_ptr_class<int> ptr{new int[5], [](int *p) { delete[] p; }};
38+
buffer b4(ptr, range<1>(5));
39+
static_assert(std::is_same_v<decltype(b4), buffer<int, 1>>);
40+
}

0 commit comments

Comments
 (0)