Skip to content

Commit c8217a8

Browse files
committed
refactor: simplify autograd padding indices
1 parent c73a140 commit c8217a8

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

tidy3d/plugins/autograd/functions.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,24 +66,15 @@ def _get_pad_indices(
6666
if n == 0:
6767
return numpy_module.zeros(total_pad, dtype=int)
6868

69-
idx = numpy_module.arange(-pad_width[0], n + pad_width[1])
70-
69+
pad_left, pad_right = pad_width
7170
if mode == "constant":
72-
return idx
73-
if mode == "edge":
74-
return numpy_module.clip(idx, 0, n - 1)
75-
if mode == "reflect":
76-
period = 2 * n - 2 if n > 1 else 1
77-
idx = numpy_module.mod(idx, period)
78-
return numpy_module.where(idx >= n, period - idx, idx)
79-
if mode == "symmetric":
80-
period = 2 * n if n > 1 else 1
81-
idx = numpy_module.mod(idx, period)
82-
return numpy_module.where(idx >= n, period - idx - 1, idx)
83-
if mode == "wrap":
84-
return numpy_module.mod(idx, n)
85-
86-
raise ValueError(f"Unsupported padding mode: {mode}")
71+
return numpy_module.arange(-pad_left, n + pad_right)
72+
73+
try:
74+
indices = onp.pad(onp.arange(n), (pad_left, pad_right), mode=mode)
75+
except ValueError as error:
76+
raise ValueError(f"Unsupported padding mode: {mode}") from error
77+
return numpy_module.asarray(indices, dtype=int)
8778

8879

8980
def pad(

0 commit comments

Comments
 (0)