Skip to content

Commit d5620b8

Browse files
committed
usm_ndarray repr and str exposed in dpctl.tensor
1 parent 8c3f3d5 commit d5620b8

File tree

4 files changed

+179
-9
lines changed

4 files changed

+179
-9
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
get_print_options,
6363
print_options,
6464
set_print_options,
65+
usm_ndarray_repr,
66+
usm_ndarray_str,
6567
)
6668
from dpctl.tensor._reshape import reshape
6769
from dpctl.tensor._usmarray import usm_ndarray
@@ -137,4 +139,6 @@
137139
"get_print_options",
138140
"set_print_options",
139141
"print_options",
142+
"usm_ndarray_repr",
143+
"usm_ndarray_str",
140144
]

dpctl/tensor/_print.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def set_print_options(
107107
):
108108
"""
109109
set_print_options(linewidth=None, edgeitems=None, threshold=None,
110-
precision=None, floatmode=None, suppress=None, nanstr=None,
111-
infstr=None, sign=None, numpy=False)
110+
precision=None, floatmode=None, suppress=None,
111+
nanstr=None, infstr=None, sign=None, numpy=False)
112112
113113
Set options for printing ``dpctl.tensor.usm_ndarray`` class.
114114
@@ -238,7 +238,7 @@ def _nd_corners(x, edge_items, slices=()):
238238
return _nd_corners(x, edge_items, slices + (slice(None, None, None),))
239239

240240

241-
def _usm_ndarray_str(
241+
def usm_ndarray_str(
242242
x,
243243
line_width=None,
244244
edge_items=None,
@@ -252,6 +252,72 @@ def _usm_ndarray_str(
252252
prefix="",
253253
suffix="",
254254
):
255+
"""
256+
usm_ndarray_str(x, line_width=None, edgeitems=None, threshold=None,
257+
precision=None, floatmode=None, suppress=None,
258+
sign=None, numpy=False, separator=" ", prefix="",
259+
suffix="") -> str
260+
261+
Returns a string representing the elements of a
262+
``dpctl.tensor.usm_ndarray``.
263+
264+
Args:
265+
x (usm_ndarray): Input array.
266+
line_width (int, optional): Number of characters printed per line.
267+
Raises `TypeError` if line_width is not an integer.
268+
Default: `75`.
269+
edgeitems (int, optional): Number of elements at the beginning and end
270+
when the printed array is abbreviated.
271+
Raises `TypeError` if edgeitems is not an integer.
272+
Default: `3`.
273+
threshold (int, optional): Number of elements that triggers array
274+
abbreviation.
275+
Raises `TypeError` if threshold is not an integer.
276+
Default: `1000`.
277+
precision (int or None, optional): Number of digits printed for
278+
floating point numbers.
279+
Raises `TypeError` if precision is not an integer.
280+
Default: `8`.
281+
floatmode (str, optional): Controls how floating point
282+
numbers are interpreted.
283+
284+
`"fixed:`: Always prints exactly `precision` digits.
285+
`"unique"`: Ignores precision, prints the number of
286+
digits necessary to uniquely specify each number.
287+
`"maxprec"`: Prints `precision` digits or fewer,
288+
if fewer will uniquely represent a number.
289+
`"maxprec_equal"`: Prints an equal number of digits
290+
for each number. This number is `precision` digits or fewer,
291+
if fewer will uniquely represent each number.
292+
Raises `ValueError` if floatmode is not one of
293+
`fixed`, `unique`, `maxprec`, or `maxprec_equal`.
294+
Default: "maxprec_equal"
295+
suppress (bool, optional): If `True,` numbers equal to zero
296+
in the current precision will print as zero.
297+
Default: `False`.
298+
sign (str, optional): Controls the sign of floating point
299+
numbers.
300+
`"-"`: Omit the sign of positive numbers.
301+
`"+"`: Always print the sign of positive numbers.
302+
`" "`: Always print a whitespace in place of the
303+
sign of positive numbers.
304+
Raises `ValueError` if sign is not one of
305+
`"-"`, `"+"`, or `" "`.
306+
Default: `"-"`.
307+
numpy (bool, optional): If `True,` then before other specified print
308+
options are set, a dictionary of Numpy's print options
309+
will be used to initialize dpctl's print options.
310+
Default: "False"
311+
separator (str, optional): String inserted between elements of
312+
the array string.
313+
Default: " "
314+
prefix (str, optional): String used to determine spacing to the left
315+
of the array string.
316+
Default: ""
317+
suffix (str, optional): String that determines length of the last line
318+
of the array string.
319+
Default: ""
320+
"""
255321
if not isinstance(x, dpt.usm_ndarray):
256322
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
257323

@@ -285,7 +351,33 @@ def _usm_ndarray_str(
285351
return s
286352

287353

288-
def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
354+
def usm_ndarray_repr(
355+
x, line_width=None, precision=None, suppress=None, prefix="usm_ndarray"
356+
):
357+
"""
358+
usm_ndarray_repr(x, line_width=None, precision=None,
359+
suppress=None, prefix="") -> str
360+
361+
Returns a formatted string representing the elements
362+
of a ``dpctl.tensor.usm_ndarray`` and its data type,
363+
if not a default type.
364+
365+
Args:
366+
x (usm_ndarray): Input array.
367+
line_width (int, optional): Number of characters printed per line.
368+
Raises `TypeError` if line_width is not an integer.
369+
Default: `75`.
370+
precision (int or None, optional): Number of digits printed for
371+
floating point numbers.
372+
Raises `TypeError` if precision is not an integer.
373+
Default: `8`.
374+
suppress (bool, optional): If `True,` numbers equal to zero
375+
in the current precision will print as zero.
376+
Default: `False`.
377+
prefix (str, optional): String inserted at the start of the array
378+
string.
379+
Default: ""
380+
"""
289381
if not isinstance(x, dpt.usm_ndarray):
290382
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
291383

@@ -299,10 +391,10 @@ def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
299391
dpt.complex128,
300392
]
301393

302-
prefix = "usm_ndarray("
394+
prefix = prefix + "("
303395
suffix = ")"
304396

305-
s = _usm_ndarray_str(
397+
s = usm_ndarray_str(
306398
x,
307399
line_width=line_width,
308400
precision=precision,

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import dpctl
2626
import dpctl.memory as dpmem
2727

2828
from ._device import Device
29-
from ._print import _usm_ndarray_repr, _usm_ndarray_str
29+
from ._print import usm_ndarray_repr, usm_ndarray_str
3030

3131
from cpython.mem cimport PyMem_Free
3232
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
@@ -1145,10 +1145,10 @@ cdef class usm_ndarray:
11451145
return self
11461146

11471147
def __str__(self):
1148-
return _usm_ndarray_str(self)
1148+
return usm_ndarray_str(self)
11491149

11501150
def __repr__(self):
1151-
return _usm_ndarray_repr(self)
1151+
return usm_ndarray_repr(self)
11521152

11531153

11541154
cdef usm_ndarray _real_view(usm_ndarray ary):

dpctl/tests/test_usm_ndarray_print.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,57 @@ def test_print_option_arg_validation(self, arg, err):
4848
with pytest.raises(err):
4949
dpt.set_print_options(**arg)
5050

51+
def test_usm_ndarray_repr_arg_validation(self):
52+
X = dict()
53+
with pytest.raises(TypeError):
54+
dpt.usm_ndarray_repr(X)
55+
56+
X = dpt.arange(4)
57+
with pytest.raises(TypeError):
58+
dpt.usm_ndarray_repr(X, line_width="I")
59+
60+
with pytest.raises(TypeError):
61+
dpt.usm_ndarray_repr(X, precision="I")
62+
63+
with pytest.raises(TypeError):
64+
dpt.usm_ndarray_repr(X, prefix=4)
65+
66+
def test_usm_ndarray_str_arg_validation(self):
67+
X = dict()
68+
with pytest.raises(TypeError):
69+
dpt.usm_ndarray_str(X)
70+
71+
X = dpt.arange(4)
72+
with pytest.raises(TypeError):
73+
dpt.usm_ndarray_str(X, line_width="I")
74+
75+
with pytest.raises(TypeError):
76+
dpt.usm_ndarray_str(X, edge_items="I")
77+
78+
with pytest.raises(TypeError):
79+
dpt.usm_ndarray_str(X, threshold="I")
80+
81+
with pytest.raises(TypeError):
82+
dpt.usm_ndarray_str(X, precision="I")
83+
84+
with pytest.raises(ValueError):
85+
dpt.usm_ndarray_str(X, floatmode="I")
86+
87+
with pytest.raises(TypeError):
88+
dpt.usm_ndarray_str(X, edge_items="I")
89+
90+
with pytest.raises(ValueError):
91+
dpt.usm_ndarray_str(X, sign="I")
92+
93+
with pytest.raises(TypeError):
94+
dpt.usm_ndarray_str(X, prefix=4)
95+
96+
with pytest.raises(TypeError):
97+
dpt.usm_ndarray_str(X, prefix=4)
98+
99+
with pytest.raises(TypeError):
100+
dpt.usm_ndarray_str(X, suffix=4)
101+
51102

52103
class TestSetPrintOptions(TestPrint):
53104
def test_set_linewidth(self):
@@ -188,6 +239,16 @@ def test_print_str_abbreviated(self):
188239
x = dpt.reshape(x, (3, 3))
189240
assert str(x) == "[[0 ... 2]\n ...\n [6 ... 8]]"
190241

242+
def test_usm_ndarray_str_separator(self):
243+
q = get_queue_or_skip()
244+
245+
x = dpt.reshape(dpt.arange(4, sycl_queue=q), (2, 2))
246+
247+
np.testing.assert_equal(
248+
dpt.usm_ndarray_str(x, prefix="test", separator=" "),
249+
"[[0 1]\n [2 3]]",
250+
)
251+
191252
def test_print_repr(self):
192253
q = get_queue_or_skip()
193254

@@ -282,6 +343,19 @@ def test_repr_appended_dtype(self, dtype):
282343
x = dpt.empty(4, dtype=dtype)
283344
assert repr(x).split("=")[-1][:-1] == x.dtype.name
284345

346+
def test_usm_ndarray_repr_prefix(self):
347+
q = get_queue_or_skip()
348+
349+
x = dpt.arange(4, sycl_queue=q)
350+
np.testing.assert_equal(
351+
dpt.usm_ndarray_repr(x, prefix="test"), "test([0, 1, 2, 3])"
352+
)
353+
x = dpt.reshape(x, (2, 2))
354+
np.testing.assert_equal(
355+
dpt.usm_ndarray_repr(x, prefix="test"),
356+
"test([[0, 1]," "\n [2, 3]])",
357+
)
358+
285359

286360
class TestContextManager:
287361
def test_context_manager_basic(self):

0 commit comments

Comments
 (0)