Skip to content

Commit 47e4ae4

Browse files
Merge pull request #1032 from IntelPython/simplify-iteration-tweak
Simplify iteration tweak
2 parents afa1b86 + eb153e6 commit 47e4ae4

File tree

3 files changed

+88
-35
lines changed

3 files changed

+88
-35
lines changed

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ template <typename srcT, typename dstT> class Caster
5151
{
5252
public:
5353
Caster() = default;
54-
void operator()(char *src,
54+
void operator()(const char *src,
5555
std::ptrdiff_t src_offset,
5656
char *dst,
5757
std::ptrdiff_t dst_offset) const
5858
{
5959
using dpctl::tensor::type_utils::convert_impl;
6060

61-
srcT *src_ = reinterpret_cast<srcT *>(src) + src_offset;
61+
const srcT *src_ = reinterpret_cast<const srcT *>(src) + src_offset;
6262
dstT *dst_ = reinterpret_cast<dstT *>(dst) + dst_offset;
6363
*dst_ = convert_impl<dstT, srcT>(*src_);
6464
}
@@ -67,17 +67,17 @@ template <typename srcT, typename dstT> class Caster
6767
template <typename CastFnT> class GenericCopyFunctor
6868
{
6969
private:
70-
char *src_ = nullptr;
70+
const char *src_ = nullptr;
7171
char *dst_ = nullptr;
72-
py::ssize_t *shape_strides_ = nullptr;
72+
const py::ssize_t *shape_strides_ = nullptr;
7373
int nd_ = 0;
7474
py::ssize_t src_offset0 = 0;
7575
py::ssize_t dst_offset0 = 0;
7676

7777
public:
78-
GenericCopyFunctor(char *src_cp,
78+
GenericCopyFunctor(const char *src_cp,
7979
char *dst_cp,
80-
py::ssize_t *shape_strides,
80+
const py::ssize_t *shape_strides,
8181
int nd,
8282
py::ssize_t src_offset,
8383
py::ssize_t dst_offset)
@@ -93,13 +93,11 @@ template <typename CastFnT> class GenericCopyFunctor
9393
CIndexer_vector<py::ssize_t> indxr(nd_);
9494
indxr.get_displacement<const py::ssize_t *, const py::ssize_t *>(
9595
static_cast<py::ssize_t>(wiid.get(0)),
96-
const_cast<const py::ssize_t *>(shape_strides_), // common shape
97-
const_cast<const py::ssize_t *>(shape_strides_ +
98-
nd_), // src strides
99-
const_cast<const py::ssize_t *>(shape_strides_ +
100-
2 * nd_), // dst strides
101-
src_offset, // modified by reference
102-
dst_offset // modified by reference
96+
shape_strides_, // common shape
97+
shape_strides_ + nd_, // src strides
98+
shape_strides_ + 2 * nd_, // dst strides
99+
src_offset, // modified by reference
100+
dst_offset // modified by reference
103101
);
104102
CastFnT fn{};
105103
fn(src_, src_offset0 + src_offset, dst_, dst_offset0 + dst_offset);
@@ -109,7 +107,7 @@ template <typename CastFnT> class GenericCopyFunctor
109107
template <int nd, typename CastFnT> class NDSpecializedCopyFunctor
110108
{
111109
private:
112-
char *src_ = nullptr;
110+
const char *src_ = nullptr;
113111
char *dst_ = nullptr;
114112
CIndexer_array<nd, py::ssize_t> indxr;
115113
const std::array<py::ssize_t, nd> src_strides_;
@@ -119,8 +117,8 @@ template <int nd, typename CastFnT> class NDSpecializedCopyFunctor
119117
py::ssize_t dst_offset0 = 0;
120118

121119
public:
122-
NDSpecializedCopyFunctor(char *src_cp, // USM pointer
123-
char *dst_cp, // USM pointer
120+
NDSpecializedCopyFunctor(const char *src_cp, // USM pointer
121+
char *dst_cp, // USM pointer
124122
const std::array<py::ssize_t, nd> shape,
125123
const std::array<py::ssize_t, nd> src_strides,
126124
const std::array<py::ssize_t, nd> dst_strides,
@@ -140,8 +138,10 @@ template <int nd, typename CastFnT> class NDSpecializedCopyFunctor
140138

141139
local_indxr.set(wiid.get(0));
142140
auto mi = local_indxr.get();
141+
#pragma unroll
143142
for (int i = 0; i < nd; ++i)
144143
src_offset += mi[i] * src_strides_[i];
144+
#pragma unroll
145145
for (int i = 0; i < nd; ++i)
146146
dst_offset += mi[i] * dst_strides_[i];
147147

@@ -161,8 +161,8 @@ typedef sycl::event (*copy_and_cast_generic_fn_ptr_t)(
161161
sycl::queue,
162162
size_t,
163163
int,
164-
py::ssize_t *,
165-
char *,
164+
const py::ssize_t *,
165+
const char *,
166166
py::ssize_t,
167167
char *,
168168
py::ssize_t,
@@ -207,8 +207,8 @@ sycl::event
207207
copy_and_cast_generic_impl(sycl::queue q,
208208
size_t nelems,
209209
int nd,
210-
py::ssize_t *shape_and_strides,
211-
char *src_p,
210+
const py::ssize_t *shape_and_strides,
211+
const char *src_p,
212212
py::ssize_t src_offset,
213213
char *dst_p,
214214
py::ssize_t dst_offset,
@@ -256,7 +256,7 @@ typedef sycl::event (*copy_and_cast_1d_fn_ptr_t)(
256256
const std::array<py::ssize_t, 1>,
257257
const std::array<py::ssize_t, 1>,
258258
const std::array<py::ssize_t, 1>,
259-
char *,
259+
const char *,
260260
py::ssize_t,
261261
char *,
262262
py::ssize_t,
@@ -272,7 +272,7 @@ typedef sycl::event (*copy_and_cast_2d_fn_ptr_t)(
272272
const std::array<py::ssize_t, 2>,
273273
const std::array<py::ssize_t, 2>,
274274
const std::array<py::ssize_t, 2>,
275-
char *,
275+
const char *,
276276
py::ssize_t,
277277
char *,
278278
py::ssize_t,
@@ -314,7 +314,7 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
314314
const std::array<py::ssize_t, nd> shape,
315315
const std::array<py::ssize_t, nd> src_strides,
316316
const std::array<py::ssize_t, nd> dst_strides,
317-
char *src_p,
317+
const char *src_p,
318318
py::ssize_t src_offset,
319319
char *dst_p,
320320
py::ssize_t dst_offset,

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,6 @@ void simplify_iteration_space(int &nd,
120120
simplified_dst_strides.resize(contracted_nd);
121121

122122
nd = contracted_nd;
123-
shape = const_cast<const py::ssize_t *>(simplified_shape.data());
124-
src_strides =
125-
const_cast<const py::ssize_t *>(simplified_src_strides.data());
126-
dst_strides =
127-
const_cast<const py::ssize_t *>(simplified_dst_strides.data());
128123
}
129124
else if (nd == 1) {
130125
// Populate vectors
@@ -171,6 +166,11 @@ void simplify_iteration_space(int &nd,
171166
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
172167
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
173168
}
169+
shape = const_cast<const py::ssize_t *>(simplified_shape.data());
170+
src_strides =
171+
const_cast<const py::ssize_t *>(simplified_src_strides.data());
172+
dst_strides =
173+
const_cast<const py::ssize_t *>(simplified_dst_strides.data());
174174
}
175175

176176
} // namespace py_internal

dpctl/tensor/libtensor/tests/test_copy.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@
4040
]
4141

4242

43+
def _typestr_has_fp64(arr_typestr):
44+
return arr_typestr in ["f8", "c16"]
45+
46+
47+
def _typestr_has_fp16(arr_typestr):
48+
return arr_typestr in ["f2"]
49+
50+
4351
@pytest.fixture(params=_usm_types_list)
4452
def usm_type(request):
4553
return request.param
@@ -95,6 +103,14 @@ def test_copy1d_c_contig(src_typestr, dst_typestr):
95103
q = dpctl.SyclQueue()
96104
except dpctl.SyclQueueCreationError:
97105
pytest.skip("Queue could not be created")
106+
if not q.sycl_device.has_aspect_fp64 and (
107+
_typestr_has_fp64(src_typestr) or _typestr_has_fp64(dst_typestr)
108+
):
109+
pytest.skip("Device does not support double precision")
110+
if not q.sycl_device.has_aspect_fp16 and (
111+
_typestr_has_fp16(src_typestr) or _typestr_has_fp16(dst_typestr)
112+
):
113+
pytest.skip("Device does not support half precision")
98114
src_dt = np.dtype(src_typestr)
99115
dst_dt = np.dtype(dst_typestr)
100116
Xnp = _random_vector(4096, src_dt)
@@ -113,6 +129,14 @@ def test_copy1d_strided(src_typestr, dst_typestr):
113129
q = dpctl.SyclQueue()
114130
except dpctl.SyclQueueCreationError:
115131
pytest.skip("Queue could not be created")
132+
if not q.sycl_device.has_aspect_fp64 and (
133+
_typestr_has_fp64(src_typestr) or _typestr_has_fp64(dst_typestr)
134+
):
135+
pytest.skip("Device does not support double precision")
136+
if not q.sycl_device.has_aspect_fp16 and (
137+
_typestr_has_fp16(src_typestr) or _typestr_has_fp16(dst_typestr)
138+
):
139+
pytest.skip("Device does not support half precision")
116140
src_dt = np.dtype(src_typestr)
117141
dst_dt = np.dtype(dst_typestr)
118142
Xnp = _random_vector(4096, src_dt)
@@ -131,7 +155,12 @@ def test_copy1d_strided(src_typestr, dst_typestr):
131155
assert are_close(Ynp, dpt.asnumpy(Y))
132156

133157
# now 0-strided source
134-
X = dpt.usm_ndarray((4096,), dtype=src_typestr, strides=(0,))
158+
X = dpt.usm_ndarray(
159+
(4096,),
160+
dtype=src_typestr,
161+
strides=(0,),
162+
buffer_ctor_kwargs={"queue": q},
163+
)
135164
X[0] = Xnp[0]
136165
Y = dpt.empty(X.shape, dtype=dst_typestr, sycl_queue=q)
137166
hev, ev = ti._copy_usm_ndarray_into_usm_ndarray(src=X, dst=Y, sycl_queue=q)
@@ -145,6 +174,14 @@ def test_copy1d_strided2(src_typestr, dst_typestr):
145174
q = dpctl.SyclQueue()
146175
except dpctl.SyclQueueCreationError:
147176
pytest.skip("Queue could not be created")
177+
if not q.sycl_device.has_aspect_fp64 and (
178+
_typestr_has_fp64(src_typestr) or _typestr_has_fp64(dst_typestr)
179+
):
180+
pytest.skip("Device does not support double precision")
181+
if not q.sycl_device.has_aspect_fp16 and (
182+
_typestr_has_fp16(src_typestr) or _typestr_has_fp16(dst_typestr)
183+
):
184+
pytest.skip("Device does not support half precision")
148185
src_dt = np.dtype(src_typestr)
149186
dst_dt = np.dtype(dst_typestr)
150187
Xnp = _random_vector(4096, src_dt)
@@ -172,6 +209,14 @@ def test_copy2d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2):
172209
q = dpctl.SyclQueue()
173210
except dpctl.SyclQueueCreationError:
174211
pytest.skip("Queue could not be created")
212+
if not q.sycl_device.has_aspect_fp64 and (
213+
_typestr_has_fp64(src_typestr) or _typestr_has_fp64(dst_typestr)
214+
):
215+
pytest.skip("Device does not support double precision")
216+
if not q.sycl_device.has_aspect_fp16 and (
217+
_typestr_has_fp16(src_typestr) or _typestr_has_fp16(dst_typestr)
218+
):
219+
pytest.skip("Device does not support half precision")
175220

176221
src_dt = np.dtype(src_typestr)
177222
dst_dt = np.dtype(dst_typestr)
@@ -188,16 +233,16 @@ def test_copy2d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2):
188233
slice(None, None, st1 * sgn1),
189234
slice(None, None, st2 * sgn2),
190235
]
191-
Y = dpt.empty((n1, n2), dtype=dst_dt)
236+
Y = dpt.empty((n1, n2), dtype=dst_dt, device=X.device)
192237
hev, ev = ti._copy_usm_ndarray_into_usm_ndarray(src=X, dst=Y, sycl_queue=q)
193238
Ynp = _force_cast(Xnp, dst_dt)
194239
hev.wait()
195240
assert are_close(Ynp, dpt.asnumpy(Y))
196-
Yst = dpt.empty((2 * n1, n2), dtype=dst_dt)[::2, ::-1]
241+
Yst = dpt.empty((2 * n1, n2), dtype=dst_dt, device=X.device)[::2, ::-1]
197242
hev, ev = ti._copy_usm_ndarray_into_usm_ndarray(
198243
src=X, dst=Yst, sycl_queue=q
199244
)
200-
Y = dpt.empty((n1, n2), dtype=dst_dt)
245+
Y = dpt.empty((n1, n2), dtype=dst_dt, device=X.device)
201246
hev2, ev2 = ti._copy_usm_ndarray_into_usm_ndarray(
202247
src=Yst, dst=Y, sycl_queue=q, depends=[ev]
203248
)
@@ -220,6 +265,14 @@ def test_copy3d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2, st3, sgn3):
220265
except dpctl.SyclQueueCreationError:
221266
pytest.skip("Queue could not be created")
222267

268+
if not q.sycl_device.has_aspect_fp64 and (
269+
_typestr_has_fp64(src_typestr) or _typestr_has_fp64(dst_typestr)
270+
):
271+
pytest.skip("Device does not support double precision")
272+
if not q.sycl_device.has_aspect_fp16 and (
273+
_typestr_has_fp16(src_typestr) or _typestr_has_fp16(dst_typestr)
274+
):
275+
pytest.skip("Device does not support half precision")
223276
src_dt = np.dtype(src_typestr)
224277
dst_dt = np.dtype(dst_typestr)
225278
n1, n2, n3 = 5, 4, 6
@@ -237,16 +290,16 @@ def test_copy3d(src_typestr, dst_typestr, st1, sgn1, st2, sgn2, st3, sgn3):
237290
slice(None, None, st2 * sgn2),
238291
slice(None, None, st3 * sgn3),
239292
]
240-
Y = dpt.empty((n1, n2, n3), dtype=dst_dt)
293+
Y = dpt.empty((n1, n2, n3), dtype=dst_dt, device=X.device)
241294
hev, ev = ti._copy_usm_ndarray_into_usm_ndarray(src=X, dst=Y, sycl_queue=q)
242295
Ynp = _force_cast(Xnp, dst_dt)
243296
hev.wait()
244297
assert are_close(Ynp, dpt.asnumpy(Y)), "1"
245-
Yst = dpt.empty((2 * n1, n2, n3), dtype=dst_dt)[::2, ::-1]
298+
Yst = dpt.empty((2 * n1, n2, n3), dtype=dst_dt, device=X.device)[::2, ::-1]
246299
hev2, ev2 = ti._copy_usm_ndarray_into_usm_ndarray(
247300
src=X, dst=Yst, sycl_queue=q
248301
)
249-
Y2 = dpt.empty((n1, n2, n3), dtype=dst_dt)
302+
Y2 = dpt.empty((n1, n2, n3), dtype=dst_dt, device=X.device)
250303
hev3, ev3 = ti._copy_usm_ndarray_into_usm_ndarray(
251304
src=Yst, dst=Y2, sycl_queue=q, depends=[ev2]
252305
)

0 commit comments

Comments
 (0)