@@ -107,8 +107,8 @@ def set_print_options(
107
107
):
108
108
"""
109
109
set_print_options(linewidth=None, edgeitems=None, threshold=None,
110
- precision=None, floatmode=None, suppress=None, nanstr =None,
111
- infstr=None, sign=None, numpy=False)
110
+ precision=None, floatmode=None, suppress=None,
111
+ nanstr=None, infstr=None, sign=None, numpy=False)
112
112
113
113
Set options for printing ``dpctl.tensor.usm_ndarray`` class.
114
114
@@ -238,7 +238,7 @@ def _nd_corners(x, edge_items, slices=()):
238
238
return _nd_corners (x , edge_items , slices + (slice (None , None , None ),))
239
239
240
240
241
- def _usm_ndarray_str (
241
+ def usm_ndarray_str (
242
242
x ,
243
243
line_width = None ,
244
244
edge_items = None ,
@@ -252,6 +252,72 @@ def _usm_ndarray_str(
252
252
prefix = "" ,
253
253
suffix = "" ,
254
254
):
255
+ """
256
+ usm_ndarray_str(x, line_width=None, edgeitems=None, threshold=None,
257
+ precision=None, floatmode=None, suppress=None,
258
+ sign=None, numpy=False, separator=" ", prefix="",
259
+ suffix="") -> str
260
+
261
+ Returns a string representing the elements of a
262
+ ``dpctl.tensor.usm_ndarray``.
263
+
264
+ Args:
265
+ x (usm_ndarray): Input array.
266
+ line_width (int, optional): Number of characters printed per line.
267
+ Raises `TypeError` if line_width is not an integer.
268
+ Default: `75`.
269
+ edgeitems (int, optional): Number of elements at the beginning and end
270
+ when the printed array is abbreviated.
271
+ Raises `TypeError` if edgeitems is not an integer.
272
+ Default: `3`.
273
+ threshold (int, optional): Number of elements that triggers array
274
+ abbreviation.
275
+ Raises `TypeError` if threshold is not an integer.
276
+ Default: `1000`.
277
+ precision (int or None, optional): Number of digits printed for
278
+ floating point numbers.
279
+ Raises `TypeError` if precision is not an integer.
280
+ Default: `8`.
281
+ floatmode (str, optional): Controls how floating point
282
+ numbers are interpreted.
283
+
284
+ `"fixed:`: Always prints exactly `precision` digits.
285
+ `"unique"`: Ignores precision, prints the number of
286
+ digits necessary to uniquely specify each number.
287
+ `"maxprec"`: Prints `precision` digits or fewer,
288
+ if fewer will uniquely represent a number.
289
+ `"maxprec_equal"`: Prints an equal number of digits
290
+ for each number. This number is `precision` digits or fewer,
291
+ if fewer will uniquely represent each number.
292
+ Raises `ValueError` if floatmode is not one of
293
+ `fixed`, `unique`, `maxprec`, or `maxprec_equal`.
294
+ Default: "maxprec_equal"
295
+ suppress (bool, optional): If `True,` numbers equal to zero
296
+ in the current precision will print as zero.
297
+ Default: `False`.
298
+ sign (str, optional): Controls the sign of floating point
299
+ numbers.
300
+ `"-"`: Omit the sign of positive numbers.
301
+ `"+"`: Always print the sign of positive numbers.
302
+ `" "`: Always print a whitespace in place of the
303
+ sign of positive numbers.
304
+ Raises `ValueError` if sign is not one of
305
+ `"-"`, `"+"`, or `" "`.
306
+ Default: `"-"`.
307
+ numpy (bool, optional): If `True,` then before other specified print
308
+ options are set, a dictionary of Numpy's print options
309
+ will be used to initialize dpctl's print options.
310
+ Default: "False"
311
+ separator (str, optional): String inserted between elements of
312
+ the array string.
313
+ Default: " "
314
+ prefix (str, optional): String used to determine spacing to the left
315
+ of the array string.
316
+ Default: ""
317
+ suffix (str, optional): String that determines length of the last line
318
+ of the array string.
319
+ Default: ""
320
+ """
255
321
if not isinstance (x , dpt .usm_ndarray ):
256
322
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
257
323
@@ -285,7 +351,33 @@ def _usm_ndarray_str(
285
351
return s
286
352
287
353
288
- def _usm_ndarray_repr (x , line_width = None , precision = None , suppress = None ):
354
+ def usm_ndarray_repr (
355
+ x , line_width = None , precision = None , suppress = None , prefix = "usm_ndarray"
356
+ ):
357
+ """
358
+ usm_ndarray_repr(x, line_width=None, precision=None,
359
+ suppress=None, prefix="") -> str
360
+
361
+ Returns a formatted string representing the elements
362
+ of a ``dpctl.tensor.usm_ndarray`` and its data type,
363
+ if not a default type.
364
+
365
+ Args:
366
+ x (usm_ndarray): Input array.
367
+ line_width (int, optional): Number of characters printed per line.
368
+ Raises `TypeError` if line_width is not an integer.
369
+ Default: `75`.
370
+ precision (int or None, optional): Number of digits printed for
371
+ floating point numbers.
372
+ Raises `TypeError` if precision is not an integer.
373
+ Default: `8`.
374
+ suppress (bool, optional): If `True,` numbers equal to zero
375
+ in the current precision will print as zero.
376
+ Default: `False`.
377
+ prefix (str, optional): String inserted at the start of the array
378
+ string.
379
+ Default: ""
380
+ """
289
381
if not isinstance (x , dpt .usm_ndarray ):
290
382
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
291
383
@@ -299,10 +391,10 @@ def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
299
391
dpt .complex128 ,
300
392
]
301
393
302
- prefix = "usm_ndarray ("
394
+ prefix = prefix + " ("
303
395
suffix = ")"
304
396
305
- s = _usm_ndarray_str (
397
+ s = usm_ndarray_str (
306
398
x ,
307
399
line_width = line_width ,
308
400
precision = precision ,
0 commit comments