@@ -66,24 +66,15 @@ def _get_pad_indices(
6666 if n == 0 :
6767 return numpy_module .zeros (total_pad , dtype = int )
6868
69- idx = numpy_module .arange (- pad_width [0 ], n + pad_width [1 ])
70-
69+ pad_left , pad_right = pad_width
7170 if mode == "constant" :
72- return idx
73- if mode == "edge" :
74- return numpy_module .clip (idx , 0 , n - 1 )
75- if mode == "reflect" :
76- period = 2 * n - 2 if n > 1 else 1
77- idx = numpy_module .mod (idx , period )
78- return numpy_module .where (idx >= n , period - idx , idx )
79- if mode == "symmetric" :
80- period = 2 * n if n > 1 else 1
81- idx = numpy_module .mod (idx , period )
82- return numpy_module .where (idx >= n , period - idx - 1 , idx )
83- if mode == "wrap" :
84- return numpy_module .mod (idx , n )
85-
86- raise ValueError (f"Unsupported padding mode: { mode } " )
71+ return numpy_module .arange (- pad_left , n + pad_right )
72+
73+ try :
74+ indices = onp .pad (onp .arange (n ), (pad_left , pad_right ), mode = mode )
75+ except ValueError as error :
76+ raise ValueError (f"Unsupported padding mode: { mode } " ) from error
77+ return numpy_module .asarray (indices , dtype = int )
8778
8879
8980def pad (
0 commit comments