1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -58,6 +58,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `web.Batch(ComponentModeler)` and `web.Job(ComponentModeler)` native support
- Simulation data from batch jobs are now downloaded automatically as each job completes in `Batch.run()`, rather than waiting for the entire batch to finish.
- Port names in `ModalComponentModeler` and `TerminalComponentModeler` can no longer include the `@` symbol.
- Improved speed of convolutions for large inputs.

### Fixed
- Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL.
136 changes: 121 additions & 15 deletions tests/test_plugins/autograd/test_functions.py
@@ -28,6 +28,7 @@
threshold,
trapz,
)
from tidy3d.plugins.autograd.functions import _normalize_axes
from tidy3d.plugins.autograd.types import PaddingType

_mode_to_scipy = {
@@ -38,6 +39,15 @@
"wrap": "wrap",
}

CONV_MODES = ["full", "same", "valid"]

_CONVOLVE_AXES_CASES = [
([0], [0]),
([1], [1]),
([1], [0]),
([-1], [-1]),
]


@pytest.mark.parametrize("mode", PaddingType.__args__)
@pytest.mark.parametrize("size", [3, 4, (3, 3), (4, 4), (3, 4), (3, 3, 3), (4, 4, 4), (3, 4, 5)])
@@ -94,7 +104,7 @@ def test_negative_axis_out_of_range(self):
pad(self.array, (1, 1), axis=-3)


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("mode", CONV_MODES)
@pytest.mark.parametrize("padding", PaddingType.__args__)
@pytest.mark.parametrize(
"ary_size", [7, 8, (7, 7), (8, 8), (7, 8), (7, 7, 7), (8, 8, 8), (7, 8, 9)]
@@ -117,22 +127,10 @@ def test_convolve_val(self, rng, mode, padding, ary_size, kernel_size, square_kernel):
"""Test convolution values against SciPy for various modes, padding, array sizes, and kernel sizes."""
x, k = self._ary_and_kernel(rng, ary_size, kernel_size, square_kernel)

if mode in ("full", "same"):
pad_widths = [(k // 2, k // 2) for k in k.shape]
x_padded = x
for axis, pad_width in enumerate(pad_widths):
x_padded = pad(x_padded, pad_width, mode=padding, axis=axis)
conv_sp = convolve_sp(x_padded, k, mode="valid" if mode == "same" else mode)
else:
conv_sp = convolve_sp(x, k, mode=mode)

conv_td = convolve(x, k, padding=padding, mode=mode)
conv_sp = _reference_convolution(x, k, mode, padding, axes=None)

npt.assert_allclose(
conv_td,
conv_sp,
atol=1e-12, # scipy's "full" somehow is not zero at the edges...
)
npt.assert_allclose(conv_td, conv_sp, atol=1e-12)

def test_convolve_grad(self, rng, mode, padding, ary_size, kernel_size, square_kernel):
"""Test gradients of convolution function for various modes, padding, array sizes, and kernel sizes."""
@@ -168,6 +166,114 @@ def test_kernel_array_dimension_mismatch(self):
convolve(self.array, kernel_mismatch)


def _reference_convolve_with_axes(array, kernel, axes_array, axes_kernel, mode):
"""Construct a SciPy reference for convolutions with explicit axes."""

array_batch_axes = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
kernel_batch_axes = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)

array_perm = array_batch_axes + axes_array
kernel_perm = kernel_batch_axes + axes_kernel

array_reordered = np.transpose(array, array_perm)
kernel_reordered = np.transpose(kernel, kernel_perm)

len_array_batch = len(array_batch_axes)
len_kernel_batch = len(kernel_batch_axes)

array_batch_shape = array_reordered.shape[:len_array_batch]
kernel_batch_shape = kernel_reordered.shape[:len_kernel_batch]

sample_conv = convolve_sp(
array_reordered[(0,) * len_array_batch],
kernel_reordered[(0,) * len_kernel_batch],
mode=mode,
)
conv_shape = sample_conv.shape

expected = np.empty(array_batch_shape + kernel_batch_shape + conv_shape)

for idx_array in np.ndindex(array_batch_shape):
array_slice = array_reordered[idx_array]
for idx_kernel in np.ndindex(kernel_batch_shape):
kernel_slice = kernel_reordered[idx_kernel]
expected[idx_array + idx_kernel] = convolve_sp(array_slice, kernel_slice, mode=mode)

return expected
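
# Shape sketch (illustrative; the `_demo_*` names exist only for this example):
# convolving axis 1 of a (2, 5) array with axis 0 of a (3, 3) kernel leaves one
# batch axis on each input, so the reference output has shape
# (array batch, kernel batch, valid length) == (2, 3, 3).
_demo_rng = np.random.default_rng(0)
_demo_ref = _reference_convolve_with_axes(_demo_rng.random((2, 5)), _demo_rng.random((3, 3)), (1,), (0,), "valid")
assert _demo_ref.shape == (2, 3, 3)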


def _prepare_reference_inputs(array, kernel, mode, padding, axes):
"""Apply padding logic to match tidy3d's convolution before building a reference."""

axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)

working_array = array
scipy_mode = mode

if mode in ("same", "full"):
for ax_array, ax_kernel in zip(axes_array, axes_kernel):
pad_width = (
kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1
)
if pad_width > 0:
working_array = pad(
working_array, (pad_width, pad_width), mode=padding, axis=ax_array
)
scipy_mode = "valid"

working_array_np = np.asarray(working_array)
kernel_np = np.asarray(kernel)

return working_array_np, kernel_np, axes_array, axes_kernel, scipy_mode
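
# Quick sketch (illustrative; the `_demo_*` names exist only for this example):
# for mode="same" the helper pads each convolved axis by kernel_size // 2 and
# hands SciPy the "valid" mode instead.
_demo_arr, _demo_ker, _, _, _demo_mode = _prepare_reference_inputs(np.arange(5.0), np.ones(3), "same", "constant", None)
assert _demo_arr.shape == (7,) and _demo_mode == "valid"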


def _reference_convolution(array, kernel, mode, padding, axes):
"""Full reference that mimics tidy3d padding rules before SciPy convolution."""

working_array_np, kernel_np, axes_array, axes_kernel, scipy_mode = _prepare_reference_inputs(
array,
kernel,
mode,
padding,
axes,
)

return _reference_convolve_with_axes(
working_array_np,
kernel_np,
axes_array,
axes_kernel,
scipy_mode,
)


@pytest.mark.parametrize("mode", CONV_MODES)
@pytest.mark.parametrize("padding", PaddingType.__args__)
@pytest.mark.parametrize("axes", _CONVOLVE_AXES_CASES)
class TestConvolveAxes:
def test_convolve_axes_val(self, rng, mode, padding, axes):
"""Test convolution with explicit axes against NumPy implementations."""
array = rng.random((2, 5))
kernel = rng.random((3, 3))

conv_td = convolve(array, kernel, padding=padding, mode=mode, axes=axes)
expected = _reference_convolution(array, kernel, mode, padding, axes)

npt.assert_allclose(conv_td, expected, atol=1e-12)

def test_convolve_axes_grad(self, rng, axes, mode, padding):
"""Test gradients of convolution when specific axes are provided."""
array = rng.random((2, 5))
kernel = rng.random((3, 3))
check_grads(convolve, modes=["rev"], order=2)(
array,
kernel,
padding=padding,
mode=mode,
axes=axes,
)


@pytest.mark.parametrize(
"op,sp_op",
[
168 changes: 154 additions & 14 deletions tidy3d/plugins/autograd/functions.py
@@ -1,17 +1,19 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Callable, Literal, Union
from typing import Callable, Literal, SupportsInt, Union

import autograd.numpy as np
import numpy as onp
from autograd import jacobian
from autograd.extend import defvjp, primitive
from autograd.scipy.signal import convolve as convolve_ag
from autograd.numpy.fft import fftn, ifftn
from autograd.scipy.special import logsumexp
from autograd.tracer import getval
from numpy.fft import irfftn, rfftn
from numpy.lib.stride_tricks import sliding_window_view
from numpy.typing import NDArray
from scipy.fft import next_fast_len

from tidy3d.components.autograd.functions import add_at, interpn, trapz

@@ -37,6 +39,140 @@
]


def _normalize_axes(
ndim_array: int,
ndim_kernel: int,
axes: Union[tuple[Iterable[SupportsInt], Iterable[SupportsInt]], None],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Normalize the axes specification for convolution."""

def _normalize_single_axis(ax: SupportsInt, ndim: int, kind: str) -> int:
if not isinstance(ax, int):
try:
ax = int(ax)
except Exception as e:
raise TypeError(f"Axis {ax!r} could not be converted to an integer.") from e

if not -ndim <= ax < ndim:
raise ValueError(f"Invalid axis {ax} for {kind} with ndim {ndim}.")
return ax + ndim if ax < 0 else ax

if axes is None:
if ndim_array != ndim_kernel:
raise ValueError(
"Kernel dimensions must match array dimensions when 'axes' is not provided, "
f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}."
)
axes_array = tuple(range(ndim_array))
axes_kernel = tuple(range(ndim_kernel))
return axes_array, axes_kernel

if len(axes) != 2:
raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.")

axes_array_raw, axes_kernel_raw = axes

axes_array = tuple(_normalize_single_axis(ax, ndim_array, "array") for ax in axes_array_raw)
axes_kernel = tuple(_normalize_single_axis(ax, ndim_kernel, "kernel") for ax in axes_kernel_raw)

if len(axes_array) != len(axes_kernel):
raise ValueError(
"The number of convolution axes for the array and kernel must be the same, "
f"got {len(axes_array)} and {len(axes_kernel)}."
)

if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel):
raise ValueError("Convolution axes must be unique for both the array and the kernel.")

return axes_array, axes_kernel
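
# Quick sketch (illustrative, not part of the module): negative indices are
# wrapped and axes=None expands to all dimensions of both inputs.
assert _normalize_axes(2, 2, ([-1], [0])) == ((1,), (0,))
assert _normalize_axes(3, 3, None) == ((0, 1, 2), (0, 1, 2))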


def _fft_convolve_general(
array: NDArray,
kernel: NDArray,
axes_array: tuple[int, ...],
axes_kernel: tuple[int, ...],
mode: Literal["full", "valid"],
) -> NDArray:
"""Perform convolution using FFT along the specified axes."""

num_conv_axes = len(axes_array)

if num_conv_axes == 0:
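# With no convolution axes, the result is the plain outer product of array
# and kernel, with shape array.shape + kernel.shape.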
array_shape = array.shape
kernel_shape = kernel.shape
result = np.multiply(
array.reshape(array_shape + (1,) * kernel.ndim),
kernel.reshape((1,) * array.ndim + kernel_shape),
)
return result.reshape(array_shape + kernel_shape)

ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)

new_order_array = ignore_axes_array + axes_array
new_order_kernel = ignore_axes_kernel + axes_kernel

array_reordered = np.transpose(array, new_order_array) if array.ndim else array
kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel

num_batch_array = len(ignore_axes_array)
num_batch_kernel = len(ignore_axes_kernel)

array_conv_shape = array_reordered.shape[num_batch_array:]
kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:]

if any(d <= 0 for d in array_conv_shape + kernel_conv_shape):
raise ValueError("Convolution dimensions must be positive; got zero-length axis.")

fft_axes = tuple(range(-num_conv_axes, 0))
fft_shape = [next_fast_len(n + k - 1) for n, k in zip(array_conv_shape, kernel_conv_shape)]
use_real_fft = fft_shape[-1] % 2 == 0  # the real-FFT path is used only for even last-axis transform lengths

fft_fun = rfftn if use_real_fft else fftn
array_fft = fft_fun(array_reordered, fft_shape, axes=fft_axes)
kernel_fft = fft_fun(kernel_reordered, fft_shape, axes=fft_axes)

if num_batch_kernel:
array_batch_shape = array_fft.shape[:num_batch_array]
conv_shape = array_fft.shape[num_batch_array:]
array_fft = np.reshape(
array_fft,
array_batch_shape + (1,) * num_batch_kernel + conv_shape,
)

if num_batch_array:
kernel_batch_shape = kernel_fft.shape[:num_batch_kernel]
conv_shape = kernel_fft.shape[num_batch_kernel:]
kernel_fft = np.reshape(
kernel_fft,
(1,) * num_batch_array + kernel_batch_shape + conv_shape,
)

product = array_fft * kernel_fft

ifft_fun = irfftn if use_real_fft else ifftn
full_result = ifft_fun(product, fft_shape, axes=fft_axes)

if mode == "full":
result = full_result
elif mode == "valid":
valid_slices = [slice(None)] * full_result.ndim
for axis_offset, (array_dim, kernel_dim) in enumerate(
zip(array_conv_shape, kernel_conv_shape)
):
start = int(min(array_dim, kernel_dim) - 1)
length = int(abs(array_dim - kernel_dim) + 1)
axis = full_result.ndim - num_conv_axes + axis_offset
valid_slices[axis] = slice(start, start + length)
result = full_result[tuple(valid_slices)]
else:
raise ValueError(f"Unsupported convolution mode '{mode}'.")

return np.real(result)
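
# Minimal self-check (illustrative, not part of the module; `_x`, `_k`, and
# `_out` are example names): a 1D "full" FFT convolution should match NumPy's
# direct convolution up to floating-point tolerance.
_x, _k = onp.random.rand(8), onp.random.rand(3)
_out = _fft_convolve_general(_x, _k, (0,), (0,), "full")
onp.testing.assert_allclose(_out, onp.convolve(_x, _k, mode="full"), atol=1e-12)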


def _get_pad_indices(
n: int,
pad_width: tuple[int, int],
@@ -148,7 +284,7 @@ def convolve(
kernel: NDArray,
*,
padding: PaddingType = "constant",
axes: Union[tuple[list[int], list[int]], None] = None,
axes: Union[tuple[list[SupportsInt], list[SupportsInt]], None] = None,
mode: Literal["full", "valid", "same"] = "same",
) -> NDArray:
"""Convolve an array with a given kernel.
@@ -180,19 +316,23 @@ def convolve(
if any(k % 2 == 0 for k in kernel.shape):
raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.")

if kernel.ndim != array.ndim and axes is None:
raise ValueError(
f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}."
)
axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)

working_array = array
effective_mode = mode

if mode in ("same", "full"):
kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]]
pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims]
for axis, pad_width in enumerate(pad_widths):
array = pad(array, pad_width, mode=padding, axis=axis)
mode = "valid" if mode == "same" else mode
if mode in ["same", "full"]:
for ax_array, ax_kernel in zip(axes_array, axes_kernel):
pad_width = (
kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1
)
if pad_width > 0:
working_array = pad(
working_array, (pad_width, pad_width), mode=padding, axis=ax_array
)
effective_mode = "valid"

return convolve_ag(array, kernel, axes=axes, mode=mode)
return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode)
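
# Usage sketch (illustrative, not part of the module; `_signals`, `_kernel`,
# and `_out` are example names): convolve a batch of signals along their last
# axis with a single odd-length 1D kernel; mode="same" preserves the input
# length along the convolved axis.
_signals, _kernel = onp.random.rand(4, 128), onp.random.rand(5)
_out = convolve(_signals, _kernel, mode="same", axes=([1], [0]))
assert _out.shape == (4, 128)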


def _get_footprint(size, structure, maxval):