Skip to content

Commit 1255f7f

Browse files
feat(tidy3d): FXC-3961-faster-convolutions-for-tidy-3-d-plugins-autograd-filters
1 parent d39a059 commit 1255f7f

File tree

3 files changed

+173
-12
lines changed

3 files changed

+173
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5454
- Unified run submission API: `web.run(...)` is now a container-aware wrapper that accepts a single simulation or arbitrarily nested containers (`list`, `tuple`, `dict` values) and returns results in the same shape.
5555
- Added native support for `web.Batch(ComponentModeler)` and `web.Job(ComponentModeler)`.
5656
- Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion.
57+
- Improved speed of autograd tracing for convolutions.
5758

5859
### Fixed
5960
- Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL.

tests/test_plugins/autograd/test_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,52 @@ def test_kernel_array_dimension_mismatch(self):
168168
convolve(self.array, kernel_mismatch)
169169

170170

171+
class TestConvolveAxes:
    """Checks for ``convolve`` when the convolution axes are passed explicitly."""

    @pytest.mark.parametrize("mode", ["valid", "same", "full"])
    @pytest.mark.parametrize("padding", ["constant", "edge"])
    def test_convolve_axes_val(self, rng, mode, padding):
        """Compare axis-restricted convolution values against row-wise ``np.convolve``."""
        array = rng.random((2, 5))
        kernel = rng.random((3, 3))

        conv_td = convolve(array, kernel, padding=padding, mode=mode, axes=([1], [1]))

        # Build the reference input the same way the test always has: for
        # "same" and "full" the array is padded by half the kernel width along
        # the convolved axis, and "same" is then evaluated as "valid".
        if mode == "valid":
            reference_input = array
            reference_mode = mode
        else:
            half_width = kernel.shape[1] // 2
            reference_input = pad(array, (half_width, half_width), mode=padding, axis=1)
            reference_mode = "valid" if mode == "same" else mode

        rows = np.asarray(reference_input)
        taps = np.asarray(kernel)

        # Non-convolved axes of array and kernel act as independent batch
        # axes, so every (array row, kernel row) pair gets its own 1D result.
        expected = np.asarray(
            [[np.convolve(row, tap, mode=reference_mode) for tap in taps] for row in rows]
        )

        npt.assert_allclose(conv_td, expected, atol=1e-12)

    def test_convolve_axes_grad(self, rng):
        """Check reverse-mode gradients (up to second order) with explicit axes."""
        array = rng.random((2, 5))
        kernel = rng.random((3, 3))

        grad_checker = check_grads(convolve, modes=["rev"], order=2)
        grad_checker(array, kernel, padding="constant", mode="valid", axes=([1], [1]))
215+
216+
171217
@pytest.mark.parametrize(
172218
"op,sp_op",
173219
[

tidy3d/plugins/autograd/functions.py

Lines changed: 126 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as onp
88
from autograd import jacobian
99
from autograd.extend import defvjp, primitive
10-
from autograd.scipy.signal import convolve as convolve_ag
10+
from autograd.numpy.fft import fftn, ifftn
1111
from autograd.scipy.special import logsumexp
1212
from autograd.tracer import getval
1313
from numpy.lib.stride_tricks import sliding_window_view
@@ -37,6 +37,118 @@
3737
]
3838

3939

40+
def _normalize_axes(
41+
ndim_array: int,
42+
ndim_kernel: int,
43+
axes: Union[tuple[Iterable[int], Iterable[int]], None],
44+
) -> tuple[tuple[int, ...], tuple[int, ...]]:
45+
"""Normalize the axes specification for convolution."""
46+
47+
if axes is None:
48+
if ndim_array != ndim_kernel:
49+
raise ValueError(
50+
"Kernel dimensions must match array dimensions when 'axes' is not provided, "
51+
f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}."
52+
)
53+
axes_array = tuple(range(ndim_array))
54+
axes_kernel = tuple(range(ndim_kernel))
55+
return axes_array, axes_kernel
56+
57+
if len(axes) != 2:
58+
raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.")
59+
60+
axes_array_raw, axes_kernel_raw = axes
61+
62+
axes_array = tuple((ax + ndim_array) % ndim_array for ax in axes_array_raw)
63+
axes_kernel = tuple((ax + ndim_kernel) % ndim_kernel for ax in axes_kernel_raw)
64+
65+
if len(axes_array) != len(axes_kernel):
66+
raise ValueError(
67+
"The number of convolution axes for the array and kernel must be the same, "
68+
f"got {len(axes_array)} and {len(axes_kernel)}."
69+
)
70+
71+
if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel):
72+
raise ValueError("Convolution axes must be unique for both the array and the kernel.")
73+
74+
return axes_array, axes_kernel
75+
76+
77+
def _fft_convolve_general(
78+
array: NDArray,
79+
kernel: NDArray,
80+
axes_array: tuple[int, ...],
81+
axes_kernel: tuple[int, ...],
82+
mode: Literal["full", "valid"],
83+
) -> NDArray:
84+
"""Perform convolution using FFT along the specified axes."""
85+
86+
num_conv_axes = len(axes_array)
87+
88+
if num_conv_axes == 0:
89+
array_shape = array.shape
90+
kernel_shape = kernel.shape
91+
result = np.multiply(
92+
array.reshape(array_shape + (1,) * kernel.ndim),
93+
kernel.reshape((1,) * array.ndim + kernel_shape),
94+
)
95+
return result.reshape(array_shape + kernel_shape)
96+
97+
ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
98+
ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)
99+
100+
new_order_array = ignore_axes_array + axes_array
101+
new_order_kernel = ignore_axes_kernel + axes_kernel
102+
103+
array_reordered = np.transpose(array, new_order_array) if array.ndim else array
104+
kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel
105+
106+
num_batch_array = len(ignore_axes_array)
107+
num_batch_kernel = len(ignore_axes_kernel)
108+
109+
array_batch_shape = array_reordered.shape[:num_batch_array]
110+
kernel_batch_shape = kernel_reordered.shape[:num_batch_kernel]
111+
112+
array_conv_shape = array_reordered.shape[num_batch_array:]
113+
kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:]
114+
115+
array_expand_shape = array_batch_shape + (1,) * num_batch_kernel + array_conv_shape
116+
kernel_expand_shape = (1,) * num_batch_array + kernel_batch_shape + kernel_conv_shape
117+
118+
array_expanded = np.reshape(array_reordered, array_expand_shape)
119+
kernel_expanded = np.reshape(kernel_reordered, kernel_expand_shape)
120+
121+
fft_axes = tuple(range(-num_conv_axes, 0))
122+
fft_shape = tuple(
123+
int(array_dim + kernel_dim - 1)
124+
for array_dim, kernel_dim in zip(array_conv_shape, kernel_conv_shape)
125+
)
126+
127+
array_fft = fftn(array_expanded, fft_shape, axes=fft_axes)
128+
kernel_fft = fftn(kernel_expanded, fft_shape, axes=fft_axes)
129+
full_result = ifftn(array_fft * kernel_fft, fft_shape, axes=fft_axes)
130+
131+
if mode == "full":
132+
result = full_result
133+
elif mode == "valid":
134+
valid_slices = [slice(None)] * full_result.ndim
135+
for axis_offset, (array_dim, kernel_dim) in enumerate(
136+
zip(array_conv_shape, kernel_conv_shape)
137+
):
138+
start = int(min(array_dim, kernel_dim) - 1)
139+
length = int(abs(array_dim - kernel_dim) + 1)
140+
axis = full_result.ndim - num_conv_axes + axis_offset
141+
valid_slices[axis] = slice(start, start + length)
142+
result = full_result[tuple(valid_slices)]
143+
else:
144+
raise ValueError(f"Unsupported convolution mode '{mode}'.")
145+
146+
if not np.iscomplexobj(array) and not np.iscomplexobj(kernel):
147+
result = np.real(result)
148+
149+
return result
150+
151+
40152
def _get_pad_indices(
41153
n: int,
42154
pad_width: tuple[int, int],
@@ -189,19 +301,21 @@ def convolve(
189301
if any(k % 2 == 0 for k in kernel.shape):
190302
raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.")
191303

192-
if kernel.ndim != array.ndim and axes is None:
193-
raise ValueError(
194-
f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}."
195-
)
304+
axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)
196305

197-
if mode in ("same", "full"):
198-
kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]]
199-
pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims]
200-
for axis, pad_width in enumerate(pad_widths):
201-
array = pad(array, pad_width, mode=padding, axis=axis)
202-
mode = "valid" if mode == "same" else mode
306+
working_array = array
307+
effective_mode = mode
203308

204-
return convolve_ag(array, kernel, axes=axes, mode=mode)
309+
if mode in ("same", "full"):
310+
for ax_array, ax_kernel in zip(axes_array, axes_kernel):
311+
pad_width = kernel.shape[ax_kernel] // 2
312+
if pad_width > 0:
313+
working_array = pad(
314+
working_array, (pad_width, pad_width), mode=padding, axis=ax_array
315+
)
316+
effective_mode = "valid" if mode == "same" else mode
317+
318+
return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode)
205319

206320

207321
def _get_footprint(size, structure, maxval):

0 commit comments

Comments
 (0)