Skip to content

Commit 23b3311

Browse files
committed
Fixed platform-specific print test failures
1 parent b173cb6 commit 23b3311

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

dpctl/tensor/_print.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,7 @@ def set_print_options(
125125
Raises `TypeError` if threshold is not an integer.
126126
Default: `1000`.
127127
precision (int or None, optional): Number of digits printed for
128-
floating point numbers. If `floatmode` is not `"fixed",`
129-
`precision` may be `None` to print each float with as many
130-
digits as necessary to produce a unique output.
128+
floating point numbers.
131129
Raises `TypeError` if precision is not an integer.
132130
Default: `8`.
133131
floatmode (str, optional): Controls how floating point

dpctl/tests/test_usm_ndarray_print.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ def test_print_str_abbreviated(self):
191191
def test_print_repr(self):
192192
q = get_queue_or_skip()
193193

194-
x = dpt.asarray(0, sycl_queue=q)
194+
x = dpt.asarray(0, dtype="int64", sycl_queue=q)
195195
assert repr(x) == "usm_ndarray(0)"
196196

197197
x = dpt.asarray([np.nan, np.inf], sycl_queue=q)
198198
assert repr(x) == "usm_ndarray([nan, inf])"
199199

200-
x = dpt.arange(9, sycl_queue=q)
200+
x = dpt.arange(9, sycl_queue=q, dtype="int64")
201201
assert repr(x) == "usm_ndarray([0, 1, 2, 3, 4, 5, 6, 7, 8])"
202202

203203
x = dpt.reshape(x, (3, 3))
@@ -208,18 +208,18 @@ def test_print_repr(self):
208208
"\n [6, 7, 8]])",
209209
)
210210

211-
x = dpt.arange(4, dtype="f2", sycl_queue=q)
212-
assert repr(x) == "usm_ndarray([0., 1., 2., 3.], dtype=float16)"
211+
x = dpt.arange(4, dtype="i4", sycl_queue=q)
212+
assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)"
213213

214214
def test_print_repr_abbreviated(self):
215215
q = get_queue_or_skip()
216216

217217
dpt.set_print_options(threshold=0, edgeitems=1)
218-
x = dpt.arange(9, sycl_queue=q)
218+
x = dpt.arange(9, dtype="int64", sycl_queue=q)
219219
assert repr(x) == "usm_ndarray([0, ..., 8])"
220220

221-
y = dpt.asarray(x, dtype="f2", copy=True)
222-
assert repr(y) == "usm_ndarray([0., ..., 8.], dtype=float16)"
221+
y = dpt.asarray(x, dtype="i4", copy=True)
222+
assert repr(y) == "usm_ndarray([0, ..., 8], dtype=int32)"
223223

224224
x = dpt.reshape(x, (3, 3))
225225
np.testing.assert_equal(
@@ -232,9 +232,9 @@ def test_print_repr_abbreviated(self):
232232
y = dpt.reshape(y, (3, 3))
233233
np.testing.assert_equal(
234234
repr(y),
235-
"usm_ndarray([[0., ..., 2.],"
235+
"usm_ndarray([[0, ..., 2],"
236236
"\n ...,"
237-
"\n [6., ..., 8.]], dtype=float16)",
237+
"\n [6, ..., 8]], dtype=int32)",
238238
)
239239

240240
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)