Skip to content

Commit f85a362

Browse files
npolina4oleksandr-pavlykantonwolfy
authored
Use linspace() function from dpctl.tensor. (#1281)
* Use linspace() function from dpctl.tensor * Convert file cupy/creation_tests/test_ranges.py to unix * Added support for array input arguments to linspace() function. * Updated linspace implementation for arrays as input argument. * Fixed linspace() function for complex dtype. * Removed extra copy in linspace() function. * Added comments for linspace() function. * Added skipping cross device tests for linspace() function on Windows. * Added reason for skipping tests for linspace() function. Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com> --------- Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com> Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent d1a90bc commit f85a362

File tree

11 files changed

+650
-459
lines changed

11 files changed

+650
-459
lines changed

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ cimport dpnp.dpnp_utils as utils
4848
cimport numpy
4949
import numpy
5050

51+
import operator
52+
5153

5254
__all__ = [
5355
"dpnp_astype",

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -189,31 +189,81 @@ cpdef utils.dpnp_descriptor dpnp_identity(n, result_dtype):
189189
return result
190190

191191

192-
# TODO this function should work through dpnp_arange_c
193-
cpdef tuple dpnp_linspace(start, stop, num, endpoint, retstep, dtype, axis):
194-
cdef shape_type_c obj_shape = utils._object_to_tuple(num)
195-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(obj_shape, dtype, None)
192+
def dpnp_linspace(start, stop, num, dtype=None, device=None, usm_type=None, sycl_queue=None, endpoint=True, retstep=False, axis=0):
193+
usm_type_alloc, sycl_queue_alloc = utils_py.get_usm_allocations([start, stop])
196194

197-
if endpoint:
198-
steps_count = num - 1
199-
else:
200-
steps_count = num
195+
# Get sycl_queue.
196+
if sycl_queue is None and device is None:
197+
sycl_queue = sycl_queue_alloc
198+
sycl_queue_normalized = dpnp.get_normalized_queue_device(sycl_queue=sycl_queue, device=device)
201199

202-
# if there are steps, then fill values
203-
if steps_count > 0:
204-
step = (dpnp.float64(stop) - start) / steps_count
205-
for i in range(1, result.size):
206-
result.get_pyobj()[i] = start + step * i
200+
# Get temporary usm_type for getting dtype.
201+
if usm_type is None:
202+
_usm_type = "device" if usm_type_alloc is None else usm_type_alloc
207203
else:
208-
step = dpnp.nan
204+
_usm_type = usm_type
205+
206+
# Get dtype.
207+
if not hasattr(start, "dtype") and not dpnp.isscalar(start):
208+
start = dpnp.asarray(start, usm_type=_usm_type, sycl_queue=sycl_queue_normalized)
209+
if not hasattr(stop, "dtype") and not dpnp.isscalar(stop):
210+
stop = dpnp.asarray(stop, usm_type=_usm_type, sycl_queue=sycl_queue_normalized)
211+
dt = numpy.result_type(start, stop, float(num))
212+
dt = utils_py.map_dtype_to_device(dt, sycl_queue_normalized.sycl_device)
213+
if dtype is None:
214+
dtype = dt
215+
216+
if dpnp.isscalar(start) and dpnp.isscalar(stop):
217+
# Call linspace() function for scalars.
218+
res = dpnp_container.linspace(start,
219+
stop,
220+
num,
221+
dtype=dt,
222+
usm_type=_usm_type,
223+
sycl_queue=sycl_queue_normalized,
224+
endpoint=endpoint)
225+
else:
226+
num = operator.index(num)
227+
if num < 0:
228+
raise ValueError("Number of points must be non-negative")
229+
230+
# Get final usm_type and copy arrays if needed with current dtype, usm_type and sycl_queue.
231+
# Do not need to copy usm_ndarray by usm_type if it is not explicitly stated.
232+
if usm_type is None:
233+
usm_type = _usm_type
234+
if not hasattr(start, "usm_type"):
235+
_start = dpnp.asarray(start, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized)
236+
else:
237+
_start = dpnp.asarray(start, dtype=dt, sycl_queue=sycl_queue_normalized)
238+
if not hasattr(stop, "usm_type"):
239+
_stop = dpnp.asarray(stop, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized)
240+
else:
241+
_stop = dpnp.asarray(stop, dtype=dt, sycl_queue=sycl_queue_normalized)
242+
else:
243+
_start = dpnp.asarray(start, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized)
244+
_stop = dpnp.asarray(stop, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized)
209245

210-
# if result is not empty, then fiil first and last elements
211-
if num > 0:
212-
result.get_pyobj()[0] = start
213-
if endpoint and result.size > 1:
214-
result.get_pyobj()[result.size - 1] = stop
246+
# FIXME: issue #1304. Mathematical operations with scalar don't follow data type.
247+
_num = dpnp.asarray((num - 1) if endpoint else num, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized)
248+
249+
step = (_stop - _start) / _num
250+
251+
res = dpnp_container.arange(0,
252+
stop=num,
253+
step=1,
254+
dtype=dt,
255+
usm_type=usm_type,
256+
sycl_queue=sycl_queue_normalized)
257+
258+
res = res.reshape((-1,) + (1,) * step.ndim)
259+
res = res * step + _start
260+
261+
if endpoint and num > 1:
262+
res[-1] = dpnp_container.full(step.shape, _stop)
215263

216-
return (result.get_pyobj(), step)
264+
if numpy.issubdtype(dtype, dpnp.integer):
265+
dpnp.floor(res, out=res)
266+
return res.astype(dtype)
217267

218268

219269
cpdef utils.dpnp_descriptor dpnp_logspace(start, stop, num, endpoint, base, dtype, axis):

dpnp/dpnp_container.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"empty",
4848
"eye",
4949
"full",
50+
"linspace",
5051
"ones"
5152
"tril",
5253
"triu",
@@ -126,6 +127,33 @@ def empty(shape,
126127
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
127128

128129

130+
def eye(N,
131+
M=None,
132+
/,
133+
*,
134+
k=0,
135+
dtype=None,
136+
order="C",
137+
device=None,
138+
usm_type="device",
139+
sycl_queue=None):
140+
"""Validate input parameters before passing them into `dpctl.tensor` module"""
141+
dpu.validate_usm_type(usm_type, allow_none=False)
142+
sycl_queue_normalized = dpnp.get_normalized_queue_device(sycl_queue=sycl_queue, device=device)
143+
if order is None:
144+
order = 'C'
145+
146+
"""Creates `dpnp_array` with ones on the `k`th diagonal."""
147+
array_obj = dpt.eye(N,
148+
M,
149+
k=k,
150+
dtype=dtype,
151+
order=order,
152+
usm_type=usm_type,
153+
sycl_queue=sycl_queue_normalized)
154+
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
155+
156+
129157
def full(shape,
130158
fill_value,
131159
*,
@@ -153,31 +181,29 @@ def full(shape,
153181
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
154182

155183

156-
def eye(N,
157-
M=None,
158-
/,
159-
*,
160-
k=0,
161-
dtype=None,
162-
order="C",
163-
device=None,
164-
usm_type="device",
165-
sycl_queue=None):
184+
def linspace(start,
185+
stop,
186+
/,
187+
num,
188+
*,
189+
dtype=None,
190+
device=None,
191+
usm_type="device",
192+
sycl_queue=None,
193+
endpoint=True):
166194
"""Validate input parameters before passing them into `dpctl.tensor` module"""
167195
dpu.validate_usm_type(usm_type, allow_none=False)
168196
sycl_queue_normalized = dpnp.get_normalized_queue_device(sycl_queue=sycl_queue, device=device)
169-
if order is None:
170-
order = 'C'
171197

172-
"""Creates `dpnp_array` with ones on the `k`th diagonal."""
173-
array_obj = dpt.eye(N,
174-
M,
175-
k=k,
176-
dtype=dtype,
177-
order=order,
178-
usm_type=usm_type,
179-
sycl_queue=sycl_queue_normalized)
180-
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
198+
"""Creates `dpnp_array` with evenly spaced numbers of specified interval."""
199+
array_obj = dpt.linspace(start,
200+
stop,
201+
num,
202+
dtype=dtype,
203+
usm_type=usm_type,
204+
sycl_queue=sycl_queue_normalized,
205+
endpoint=endpoint)
206+
return dpnp_array(array_obj.shape, buffer=array_obj)
181207

182208

183209
def meshgrid(*xi, indexing="xy"):

dpnp/dpnp_iface_arraycreation.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
import dpnp.dpnp_container as dpnp_container
5252
import dpctl.tensor as dpt
53+
import dpctl
5354

5455

5556
__all__ = [
@@ -879,7 +880,18 @@ def identity(n, dtype=None, *, like=None):
879880
return call_origin(numpy.identity, n, dtype=dtype, like=like)
880881

881882

882-
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
883+
def linspace(start,
884+
stop,
885+
/,
886+
num,
887+
*,
888+
dtype=None,
889+
device=None,
890+
usm_type=None,
891+
sycl_queue=None,
892+
endpoint=True,
893+
retstep=False,
894+
axis=0):
883895
"""
884896
Return evenly spaced numbers over a specified interval.
885897
@@ -888,6 +900,8 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
888900
Limitations
889901
-----------
890902
Parameter ``axis`` is supported only with default value ``0``.
903+
Parameter ``retstep`` is supported only with default value ``False``.
904+
Otherwise the function will be executed sequentially on CPU.
891905
892906
See Also
893907
--------
@@ -913,16 +927,19 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
913927
914928
"""
915929

916-
if not use_origin_backend():
917-
if axis != 0:
918-
checker_throw_value_error("linspace", "axis", axis, 0)
919-
920-
res = dpnp_linspace(start, stop, num, endpoint, retstep, dtype, axis)
921-
922-
if retstep:
923-
return res
924-
else:
925-
return res[0]
930+
if retstep is not False:
931+
pass
932+
elif axis != 0:
933+
pass
934+
else:
935+
return dpnp_linspace(start,
936+
stop,
937+
num,
938+
dtype=dtype,
939+
device=device,
940+
usm_type=usm_type,
941+
sycl_queue=sycl_queue,
942+
endpoint=endpoint)
926943

927944
return call_origin(numpy.linspace, start, stop, num, endpoint, retstep, dtype, axis)
928945

tests/skipped_tests.tbl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,10 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3
418418
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid3
419419
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid4
420420
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid5
421-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop
422421
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop_axis1
423-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_float_underflow
424-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_mixed_start_stop
425-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_mixed_start_stop2
426-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_start_stop_list
422+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_one_num_no_endopoint_with_retstep
423+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_with_retstep
424+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_zero_num_no_endopoint_with_retstep
427425
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
428426
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
429427
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1

tests/skipped_tests_gpu.tbl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -594,19 +594,13 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid5
594594
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid3
595595
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid4
596596
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid5
597-
598-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop
599-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop_axis1
600597
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_arange_negative_size
601598
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_arange_no_dtype_int
602-
603-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_float_underflow
604-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_mixed_start_stop
605-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_mixed_start_stop2
606-
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_start_stop_list
607-
599+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop_axis1
600+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_one_num_no_endopoint_with_retstep
601+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_with_retstep
602+
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_zero_num_no_endopoint_with_retstep
608603
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_logspace_zero_num
609-
610604
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2
611605
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2
612606
tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2

tests/test_arraycreation.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,51 @@ def test_dpctl_tensor_input(func, args):
513513
assert_array_equal(X, Y)
514514

515515

516+
@pytest.mark.parametrize("start",
517+
[0, -5, 10, -2.5, 9.7],
518+
ids=['0', '-5', '10', '-2.5', '9.7'])
519+
@pytest.mark.parametrize("stop",
520+
[0, 10, -2, 20.5, 1000],
521+
ids=['0', '10', '-2', '20.5', '1000'])
522+
@pytest.mark.parametrize("num",
523+
[5, numpy.array(10), dpnp.array(17), dpt.asarray(100)],
524+
ids=['5', 'numpy.array(10)', 'dpnp.array(17)', 'dpt.asarray(100)'])
525+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_float16=False))
526+
def test_linspace(start, stop, num, dtype):
527+
func = lambda xp: xp.linspace(start, stop, num, dtype=dtype)
528+
529+
if numpy.issubdtype(dtype, dpnp.integer):
530+
assert_allclose(func(numpy), func(dpnp), rtol=1)
531+
else:
532+
assert_allclose(func(numpy), func(dpnp), atol=numpy.finfo(dtype).eps)
533+
534+
535+
@pytest.mark.parametrize("start_dtype",
536+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
537+
ids=['float64', 'float32', 'int64', 'int32'])
538+
@pytest.mark.parametrize("stop_dtype",
539+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
540+
ids=['float64', 'float32', 'int64', 'int32'])
541+
def test_linspace_dtype(start_dtype, stop_dtype):
542+
start = numpy.array([1, 2, 3], dtype=start_dtype)
543+
stop = numpy.array([11, 7, -2], dtype=stop_dtype)
544+
dpnp.linspace(start, stop, 10)
545+
546+
547+
@pytest.mark.parametrize("start",
548+
[dpnp.array(1), dpnp.array([2.6]), numpy.array([[-6.7, 3]]), [1, -4], (3, 5)])
549+
@pytest.mark.parametrize("stop",
550+
[dpnp.array([-4]), dpnp.array([[2.6], [- 4]]), numpy.array(2), [[-4.6]], (3,)])
551+
def test_linspace_arrays(start, stop):
552+
func = lambda xp: xp.linspace(start, stop, 10)
553+
assert func(numpy).shape == func(dpnp).shape
554+
555+
556+
def test_linspace_complex():
557+
func = lambda xp: xp.linspace(0, 3 + 2j, num=1000)
558+
assert_allclose(func(numpy), func(dpnp))
559+
560+
516561
@pytest.mark.parametrize("arrays",
517562
[[], [[1]], [[1, 2, 3], [4, 5, 6]], [[1, 2], [3, 4], [5, 6]]],
518563
ids=['[]', '[[1]]', '[[1, 2, 3], [4, 5, 6]]', '[[1, 2], [3, 4], [5, 6]]'])

tests/test_special.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def test_erf():
77
a = numpy.linspace(2.0, 3.0, num=10)
88
ia = dpnp.linspace(2.0, 3.0, num=10)
99

10-
numpy.testing.assert_array_equal(a, ia)
10+
numpy.testing.assert_allclose(a, ia)
1111

1212
expected = numpy.empty_like(a)
1313
for idx, val in enumerate(a):

0 commit comments

Comments
 (0)