Skip to content

Linspace should not use double type unconditionally #878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,15 @@ def _copy_from_numpy_into(dst, np_ary):
if not isinstance(dst, dpt.usm_ndarray):
raise TypeError("Expected usm_ndarray, got {}".format(type(dst)))
src_ary = np.broadcast_to(np_ary, dst.shape)
copy_q = dst.sycl_queue
if copy_q.sycl_device.has_aspect_fp64 is False:
src_ary_dt_c = src_ary.dtype.char
if src_ary_dt_c == "d":
src_ary = src_ary.astype(np.float32)
elif src_ary_dt_c == "D":
src_ary = src_ary.astype(np.complex64)
ti._copy_numpy_ndarray_into_usm_ndarray(
src=src_ary, dst=dst, sycl_queue=dst.sycl_queue
src=src_ary, dst=dst, sycl_queue=copy_q
)


Expand Down
28 changes: 19 additions & 9 deletions dpctl/tensor/libtensor/source/tensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ template <typename srcT, typename dstT> class copy_cast_from_host_kernel;
template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel;
template <typename Ty> class copy_for_reshape_generic_kernel;
template <typename Ty> class linear_sequence_step_kernel;
template <typename Ty> class linear_sequence_affine_kernel;
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;

static dpctl::tensor::detail::usm_ndarray_types array_types;

Expand Down Expand Up @@ -1526,7 +1526,7 @@ typedef sycl::event (*lin_space_affine_fn_ptr_t)(
static lin_space_affine_fn_ptr_t
lin_space_affine_dispatch_vector[_ns::num_types];

template <typename Ty> class LinearSequenceAffineFunctor
template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
{
private:
Ty *p = nullptr;
Expand All @@ -1544,8 +1544,8 @@ template <typename Ty> class LinearSequenceAffineFunctor
void operator()(sycl::id<1> wiid) const
{
auto i = wiid.get(0);
double wc = double(i) / n;
double w = double(n - i) / n;
wTy wc = wTy(i) / n;
wTy w = wTy(n - i) / n;
if constexpr (is_complex<Ty>::value) {
auto _w = static_cast<typename Ty::value_type>(w);
auto _wc = static_cast<typename Ty::value_type>(wc);
Expand Down Expand Up @@ -1578,13 +1578,23 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
throw;
}

bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64);
sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.parallel_for<linear_sequence_affine_kernel<Ty>>(
sycl::range<1>{nelems},
LinearSequenceAffineFunctor<Ty>(array_data, start_v, end_v,
(include_endpoint) ? nelems - 1
: nelems));
if (device_supports_doubles) {
cgh.parallel_for<linear_sequence_affine_kernel<Ty, double>>(
sycl::range<1>{nelems},
LinearSequenceAffineFunctor<Ty, double>(
array_data, start_v, end_v,
(include_endpoint) ? nelems - 1 : nelems));
}
else {
cgh.parallel_for<linear_sequence_affine_kernel<Ty, float>>(
sycl::range<1>{nelems},
LinearSequenceAffineFunctor<Ty, float>(
array_data, start_v, end_v,
(include_endpoint) ? nelems - 1 : nelems));
}
});

return lin_space_affine_event;
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ def test_linspace(dt):
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Default queue could not be created")
if dt in ["f8", "c16"] and not q.sycl_device.has_aspect_fp64:
pytest.skip("Device does not support double precision")
X = dpt.linspace(0, 1, num=2, dtype=dt, sycl_queue=q)
assert np.allclose(dpt.asnumpy(X), np.linspace(0, 1, num=2, dtype=dt))

Expand Down
42 changes: 21 additions & 21 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_expand_dims_tuple(axes):
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

Xnp = np.empty((3, 3, 3))
Xnp = np.empty((3, 3, 3), dtype="u1")
X = dpt.asarray(Xnp, sycl_queue=q)
Y = dpt.expand_dims(X, axes)
Ynp = np.expand_dims(Xnp, axes)
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_squeeze_without_axes(shapes):
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

Xnp = np.empty(shapes)
Xnp = np.empty(shapes, dtype="u1")
X = dpt.asarray(Xnp, sycl_queue=q)
Y = dpt.squeeze(X)
Ynp = Xnp.squeeze()
Expand All @@ -248,7 +248,7 @@ def test_squeeze_axes_arg(axes):
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

Xnp = np.array([[[1], [2], [3]]])
Xnp = np.array([[[1], [2], [3]]], dtype="u1")
X = dpt.asarray(Xnp, sycl_queue=q)
Y = dpt.squeeze(X, axes)
Ynp = Xnp.squeeze(axes)
Expand All @@ -262,29 +262,29 @@ def test_squeeze_axes_arg_error(axes):
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

Xnp = np.array([[[1], [2], [3]]])
Xnp = np.array([[[1], [2], [3]]], dtype="u1")
X = dpt.asarray(Xnp, sycl_queue=q)
pytest.raises(ValueError, dpt.squeeze, X, axes)


@pytest.mark.parametrize(
"data",
[
[np.array(0), (0,)],
[np.array(0), (1,)],
[np.array(0), (3,)],
[np.ones(1), (1,)],
[np.ones(1), (2,)],
[np.ones(1), (1, 2, 3)],
[np.arange(3), (3,)],
[np.arange(3), (1, 3)],
[np.arange(3), (2, 3)],
[np.ones(0), 0],
[np.ones(1), 1],
[np.ones(1), 2],
[np.ones(1), (0,)],
[np.ones((1, 2)), (0, 2)],
[np.ones((2, 1)), (2, 0)],
[np.array(0, dtype="u1"), (0,)],
[np.array(0, dtype="u1"), (1,)],
[np.array(0, dtype="u1"), (3,)],
[np.ones(1, dtype="u1"), (1,)],
[np.ones(1, dtype="u1"), (2,)],
[np.ones(1, dtype="u1"), (1, 2, 3)],
[np.arange(3, dtype="u1"), (3,)],
[np.arange(3, dtype="u1"), (1, 3)],
[np.arange(3, dtype="u1"), (2, 3)],
[np.ones(0, dtype="u1"), 0],
[np.ones(1, dtype="u1"), 1],
[np.ones(1, dtype="u1"), 2],
[np.ones(1, dtype="u1"), (0,)],
[np.ones((1, 2), dtype="u1"), (0, 2)],
[np.ones((2, 1), dtype="u1"), (2, 0)],
],
)
def test_broadcast_to_succeeds(data):
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_broadcast_to_raises(data):
pytest.skip("Queue could not be created")

orig_shape, target_shape = data
Xnp = np.zeros(orig_shape)
Xnp = np.zeros(orig_shape, dtype="i1")
X = dpt.asarray(Xnp, sycl_queue=q)
pytest.raises(ValueError, dpt.broadcast_to, X, target_shape)

Expand All @@ -333,7 +333,7 @@ def assert_broadcast_correct(input_shapes):
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")
np_arrays = [np.zeros(s) for s in input_shapes]
np_arrays = [np.zeros(s, dtype="i1") for s in input_shapes]
out_np_arrays = np.broadcast_arrays(*np_arrays)
usm_arrays = [dpt.asarray(Xnp, sycl_queue=q) for Xnp in np_arrays]
out_usm_arrays = dpt.broadcast_arrays(*usm_arrays)
Expand Down