Skip to content

Commit b68b1e4

Browse files
Merge pull request #878 from IntelPython/linspace-should-not-use-double-type-unconditionally
Linspace should not use double type unconditionally
2 parents: d2da805 + e363969; commit b68b1e4

File tree

4 files changed: 50 additions and 31 deletions.

dpctl/tensor/_copy_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,15 @@ def _copy_from_numpy_into(dst, np_ary):
8181
if not isinstance(dst, dpt.usm_ndarray):
8282
raise TypeError("Expected usm_ndarray, got {}".format(type(dst)))
8383
src_ary = np.broadcast_to(np_ary, dst.shape)
84+
copy_q = dst.sycl_queue
85+
if copy_q.sycl_device.has_aspect_fp64 is False:
86+
src_ary_dt_c = src_ary.dtype.char
87+
if src_ary_dt_c == "d":
88+
src_ary = src_ary.astype(np.float32)
89+
elif src_ary_dt_c == "D":
90+
src_ary = src_ary.astype(np.complex64)
8491
ti._copy_numpy_ndarray_into_usm_ndarray(
85-
src=src_ary, dst=dst, sycl_queue=dst.sycl_queue
92+
src=src_ary, dst=dst, sycl_queue=copy_q
8693
)
8794

8895

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ template <typename srcT, typename dstT> class copy_cast_from_host_kernel;
4343
template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel;
4444
template <typename Ty> class copy_for_reshape_generic_kernel;
4545
template <typename Ty> class linear_sequence_step_kernel;
46-
template <typename Ty> class linear_sequence_affine_kernel;
46+
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
4747

4848
static dpctl::tensor::detail::usm_ndarray_types array_types;
4949

@@ -1526,7 +1526,7 @@ typedef sycl::event (*lin_space_affine_fn_ptr_t)(
15261526
static lin_space_affine_fn_ptr_t
15271527
lin_space_affine_dispatch_vector[_ns::num_types];
15281528

1529-
template <typename Ty> class LinearSequenceAffineFunctor
1529+
template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
15301530
{
15311531
private:
15321532
Ty *p = nullptr;
@@ -1544,8 +1544,8 @@ template <typename Ty> class LinearSequenceAffineFunctor
15441544
void operator()(sycl::id<1> wiid) const
15451545
{
15461546
auto i = wiid.get(0);
1547-
double wc = double(i) / n;
1548-
double w = double(n - i) / n;
1547+
wTy wc = wTy(i) / n;
1548+
wTy w = wTy(n - i) / n;
15491549
if constexpr (is_complex<Ty>::value) {
15501550
auto _w = static_cast<typename Ty::value_type>(w);
15511551
auto _wc = static_cast<typename Ty::value_type>(wc);
@@ -1578,13 +1578,23 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
15781578
throw;
15791579
}
15801580

1581+
bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64);
15811582
sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) {
15821583
cgh.depends_on(depends);
1583-
cgh.parallel_for<linear_sequence_affine_kernel<Ty>>(
1584-
sycl::range<1>{nelems},
1585-
LinearSequenceAffineFunctor<Ty>(array_data, start_v, end_v,
1586-
(include_endpoint) ? nelems - 1
1587-
: nelems));
1584+
if (device_supports_doubles) {
1585+
cgh.parallel_for<linear_sequence_affine_kernel<Ty, double>>(
1586+
sycl::range<1>{nelems},
1587+
LinearSequenceAffineFunctor<Ty, double>(
1588+
array_data, start_v, end_v,
1589+
(include_endpoint) ? nelems - 1 : nelems));
1590+
}
1591+
else {
1592+
cgh.parallel_for<linear_sequence_affine_kernel<Ty, float>>(
1593+
sycl::range<1>{nelems},
1594+
LinearSequenceAffineFunctor<Ty, float>(
1595+
array_data, start_v, end_v,
1596+
(include_endpoint) ? nelems - 1 : nelems));
1597+
}
15881598
});
15891599

15901600
return lin_space_affine_event;

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,8 @@ def test_linspace(dt):
10801080
q = dpctl.SyclQueue()
10811081
except dpctl.SyclQueueCreationError:
10821082
pytest.skip("Default queue could not be created")
1083+
if dt in ["f8", "c16"] and not q.sycl_device.has_aspect_fp64:
1084+
pytest.skip("Device does not support double precision")
10831085
X = dpt.linspace(0, 1, num=2, dtype=dt, sycl_queue=q)
10841086
assert np.allclose(dpt.asnumpy(X), np.linspace(0, 1, num=2, dtype=dt))
10851087

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_expand_dims_tuple(axes):
152152
except dpctl.SyclQueueCreationError:
153153
pytest.skip("Queue could not be created")
154154

155-
Xnp = np.empty((3, 3, 3))
155+
Xnp = np.empty((3, 3, 3), dtype="u1")
156156
X = dpt.asarray(Xnp, sycl_queue=q)
157157
Y = dpt.expand_dims(X, axes)
158158
Ynp = np.expand_dims(Xnp, axes)
@@ -234,7 +234,7 @@ def test_squeeze_without_axes(shapes):
234234
except dpctl.SyclQueueCreationError:
235235
pytest.skip("Queue could not be created")
236236

237-
Xnp = np.empty(shapes)
237+
Xnp = np.empty(shapes, dtype="u1")
238238
X = dpt.asarray(Xnp, sycl_queue=q)
239239
Y = dpt.squeeze(X)
240240
Ynp = Xnp.squeeze()
@@ -248,7 +248,7 @@ def test_squeeze_axes_arg(axes):
248248
except dpctl.SyclQueueCreationError:
249249
pytest.skip("Queue could not be created")
250250

251-
Xnp = np.array([[[1], [2], [3]]])
251+
Xnp = np.array([[[1], [2], [3]]], dtype="u1")
252252
X = dpt.asarray(Xnp, sycl_queue=q)
253253
Y = dpt.squeeze(X, axes)
254254
Ynp = Xnp.squeeze(axes)
@@ -262,29 +262,29 @@ def test_squeeze_axes_arg_error(axes):
262262
except dpctl.SyclQueueCreationError:
263263
pytest.skip("Queue could not be created")
264264

265-
Xnp = np.array([[[1], [2], [3]]])
265+
Xnp = np.array([[[1], [2], [3]]], dtype="u1")
266266
X = dpt.asarray(Xnp, sycl_queue=q)
267267
pytest.raises(ValueError, dpt.squeeze, X, axes)
268268

269269

270270
@pytest.mark.parametrize(
271271
"data",
272272
[
273-
[np.array(0), (0,)],
274-
[np.array(0), (1,)],
275-
[np.array(0), (3,)],
276-
[np.ones(1), (1,)],
277-
[np.ones(1), (2,)],
278-
[np.ones(1), (1, 2, 3)],
279-
[np.arange(3), (3,)],
280-
[np.arange(3), (1, 3)],
281-
[np.arange(3), (2, 3)],
282-
[np.ones(0), 0],
283-
[np.ones(1), 1],
284-
[np.ones(1), 2],
285-
[np.ones(1), (0,)],
286-
[np.ones((1, 2)), (0, 2)],
287-
[np.ones((2, 1)), (2, 0)],
273+
[np.array(0, dtype="u1"), (0,)],
274+
[np.array(0, dtype="u1"), (1,)],
275+
[np.array(0, dtype="u1"), (3,)],
276+
[np.ones(1, dtype="u1"), (1,)],
277+
[np.ones(1, dtype="u1"), (2,)],
278+
[np.ones(1, dtype="u1"), (1, 2, 3)],
279+
[np.arange(3, dtype="u1"), (3,)],
280+
[np.arange(3, dtype="u1"), (1, 3)],
281+
[np.arange(3, dtype="u1"), (2, 3)],
282+
[np.ones(0, dtype="u1"), 0],
283+
[np.ones(1, dtype="u1"), 1],
284+
[np.ones(1, dtype="u1"), 2],
285+
[np.ones(1, dtype="u1"), (0,)],
286+
[np.ones((1, 2), dtype="u1"), (0, 2)],
287+
[np.ones((2, 1), dtype="u1"), (2, 0)],
288288
],
289289
)
290290
def test_broadcast_to_succeeds(data):
@@ -323,7 +323,7 @@ def test_broadcast_to_raises(data):
323323
pytest.skip("Queue could not be created")
324324

325325
orig_shape, target_shape = data
326-
Xnp = np.zeros(orig_shape)
326+
Xnp = np.zeros(orig_shape, dtype="i1")
327327
X = dpt.asarray(Xnp, sycl_queue=q)
328328
pytest.raises(ValueError, dpt.broadcast_to, X, target_shape)
329329

@@ -333,7 +333,7 @@ def assert_broadcast_correct(input_shapes):
333333
q = dpctl.SyclQueue()
334334
except dpctl.SyclQueueCreationError:
335335
pytest.skip("Queue could not be created")
336-
np_arrays = [np.zeros(s) for s in input_shapes]
336+
np_arrays = [np.zeros(s, dtype="i1") for s in input_shapes]
337337
out_np_arrays = np.broadcast_arrays(*np_arrays)
338338
usm_arrays = [dpt.asarray(Xnp, sycl_queue=q) for Xnp in np_arrays]
339339
out_usm_arrays = dpt.broadcast_arrays(*usm_arrays)

0 commit comments

Comments (0)