Skip to content

Commit eb153e6

Browse files
Added const qualifiers to read-only pointers for copy-and-cast kernels
This change made it possible to remove some uses of const_cast and made code simpler. Also used #pragma unroll in specialized CopyAndCast kernel where displacement is computed from multi-index.
1 parent 73efa36 commit eb153e6

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ template <typename srcT, typename dstT> class Caster
5151
{
5252
public:
5353
Caster() = default;
54-
void operator()(char *src,
54+
void operator()(const char *src,
5555
std::ptrdiff_t src_offset,
5656
char *dst,
5757
std::ptrdiff_t dst_offset) const
5858
{
5959
using dpctl::tensor::type_utils::convert_impl;
6060

61-
srcT *src_ = reinterpret_cast<srcT *>(src) + src_offset;
61+
const srcT *src_ = reinterpret_cast<const srcT *>(src) + src_offset;
6262
dstT *dst_ = reinterpret_cast<dstT *>(dst) + dst_offset;
6363
*dst_ = convert_impl<dstT, srcT>(*src_);
6464
}
@@ -67,17 +67,17 @@ template <typename srcT, typename dstT> class Caster
6767
template <typename CastFnT> class GenericCopyFunctor
6868
{
6969
private:
70-
char *src_ = nullptr;
70+
const char *src_ = nullptr;
7171
char *dst_ = nullptr;
72-
py::ssize_t *shape_strides_ = nullptr;
72+
const py::ssize_t *shape_strides_ = nullptr;
7373
int nd_ = 0;
7474
py::ssize_t src_offset0 = 0;
7575
py::ssize_t dst_offset0 = 0;
7676

7777
public:
78-
GenericCopyFunctor(char *src_cp,
78+
GenericCopyFunctor(const char *src_cp,
7979
char *dst_cp,
80-
py::ssize_t *shape_strides,
80+
const py::ssize_t *shape_strides,
8181
int nd,
8282
py::ssize_t src_offset,
8383
py::ssize_t dst_offset)
@@ -93,13 +93,11 @@ template <typename CastFnT> class GenericCopyFunctor
9393
CIndexer_vector<py::ssize_t> indxr(nd_);
9494
indxr.get_displacement<const py::ssize_t *, const py::ssize_t *>(
9595
static_cast<py::ssize_t>(wiid.get(0)),
96-
const_cast<const py::ssize_t *>(shape_strides_), // common shape
97-
const_cast<const py::ssize_t *>(shape_strides_ +
98-
nd_), // src strides
99-
const_cast<const py::ssize_t *>(shape_strides_ +
100-
2 * nd_), // dst strides
101-
src_offset, // modified by reference
102-
dst_offset // modified by reference
96+
shape_strides_, // common shape
97+
shape_strides_ + nd_, // src strides
98+
shape_strides_ + 2 * nd_, // dst strides
99+
src_offset, // modified by reference
100+
dst_offset // modified by reference
103101
);
104102
CastFnT fn{};
105103
fn(src_, src_offset0 + src_offset, dst_, dst_offset0 + dst_offset);
@@ -109,7 +107,7 @@ template <typename CastFnT> class GenericCopyFunctor
109107
template <int nd, typename CastFnT> class NDSpecializedCopyFunctor
110108
{
111109
private:
112-
char *src_ = nullptr;
110+
const char *src_ = nullptr;
113111
char *dst_ = nullptr;
114112
CIndexer_array<nd, py::ssize_t> indxr;
115113
const std::array<py::ssize_t, nd> src_strides_;
@@ -119,8 +117,8 @@ template <int nd, typename CastFnT> class NDSpecializedCopyFunctor
119117
py::ssize_t dst_offset0 = 0;
120118

121119
public:
122-
NDSpecializedCopyFunctor(char *src_cp, // USM pointer
123-
char *dst_cp, // USM pointer
120+
NDSpecializedCopyFunctor(const char *src_cp, // USM pointer
121+
char *dst_cp, // USM pointer
124122
const std::array<py::ssize_t, nd> shape,
125123
const std::array<py::ssize_t, nd> src_strides,
126124
const std::array<py::ssize_t, nd> dst_strides,
@@ -140,8 +138,10 @@ template <int nd, typename CastFnT> class NDSpecializedCopyFunctor
140138

141139
local_indxr.set(wiid.get(0));
142140
auto mi = local_indxr.get();
141+
#pragma unroll
143142
for (int i = 0; i < nd; ++i)
144143
src_offset += mi[i] * src_strides_[i];
144+
#pragma unroll
145145
for (int i = 0; i < nd; ++i)
146146
dst_offset += mi[i] * dst_strides_[i];
147147

@@ -161,8 +161,8 @@ typedef sycl::event (*copy_and_cast_generic_fn_ptr_t)(
161161
sycl::queue,
162162
size_t,
163163
int,
164-
py::ssize_t *,
165-
char *,
164+
const py::ssize_t *,
165+
const char *,
166166
py::ssize_t,
167167
char *,
168168
py::ssize_t,
@@ -207,8 +207,8 @@ sycl::event
207207
copy_and_cast_generic_impl(sycl::queue q,
208208
size_t nelems,
209209
int nd,
210-
py::ssize_t *shape_and_strides,
211-
char *src_p,
210+
const py::ssize_t *shape_and_strides,
211+
const char *src_p,
212212
py::ssize_t src_offset,
213213
char *dst_p,
214214
py::ssize_t dst_offset,
@@ -256,7 +256,7 @@ typedef sycl::event (*copy_and_cast_1d_fn_ptr_t)(
256256
const std::array<py::ssize_t, 1>,
257257
const std::array<py::ssize_t, 1>,
258258
const std::array<py::ssize_t, 1>,
259-
char *,
259+
const char *,
260260
py::ssize_t,
261261
char *,
262262
py::ssize_t,
@@ -272,7 +272,7 @@ typedef sycl::event (*copy_and_cast_2d_fn_ptr_t)(
272272
const std::array<py::ssize_t, 2>,
273273
const std::array<py::ssize_t, 2>,
274274
const std::array<py::ssize_t, 2>,
275-
char *,
275+
const char *,
276276
py::ssize_t,
277277
char *,
278278
py::ssize_t,
@@ -314,7 +314,7 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
314314
const std::array<py::ssize_t, nd> shape,
315315
const std::array<py::ssize_t, nd> src_strides,
316316
const std::array<py::ssize_t, nd> dst_strides,
317-
char *src_p,
317+
const char *src_p,
318318
py::ssize_t src_offset,
319319
char *dst_p,
320320
py::ssize_t dst_offset,

0 commit comments

Comments
 (0)