@@ -407,6 +407,43 @@ def _concat_diff_input(arr, axis, prepend, append):
407407
408408
409409def diff (x , / , * , axis = - 1 , n = 1 , prepend = None , append = None ):
410+ """
411+ Calculates the `n`-th discrete forward difference of `x` along `axis`.
412+
413+ Args:
414+ x (usm_ndarray):
415+ input array.
416+ axis (int):
417+ axis along which to compute the difference. A valid axis must be on
418+ the interval `[-N, N)`, where `N` is the rank (number of
419+ dimensions) of `x`.
420+ Default: `-1`
421+ n (int):
422+ number of times to recursively compute the difference.
423+ Default: `1`.
424+ prepend (Union[usm_ndarray, bool, int, float, complex]):
425+ value or values to prepend to the specified axis before taking the
426+ difference.
427+ Must have the same shape as `x` except along `axis`, which can have
428+ any shape.
429+ Default: `None`.
430+ append (Union[usm_ndarray, bool, int, float, complex]):
431+ value or values to append to the specified axis before taking the
432+ difference.
433+ Must have the same shape as `x` except along `axis`, which can have
434+ any shape.
435+ Default: `None`.
436+
437+ Returns:
438+ usm_ndarray:
439+ an array containing the `n`-th differences. The array will have the
440+ same shape as `x`, except along `axis`, which will have shape
441+
442+ - prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n
443+
444+ The data type of the returned array is determined by the Type
445+ Promotion Rules.
446+ """
410447
411448 if not isinstance (x , dpt .usm_ndarray ):
412449 raise TypeError (
@@ -417,7 +454,8 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
417454 n = operator .index (n )
418455
419456 arr = _concat_diff_input (x , axis , prepend , append )
420-
457+ if n == 0 :
458+ return arr
421459 # form slices and recurse
422460 sl0 = tuple (
423461 slice (None ) if i != axis else slice (1 , None ) for i in range (x_nd )
0 commit comments