Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions spectral_connectivity/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,13 @@ def _noise_covariance(self) -> NDArray[np.floating]:

@property
def _MVAR_Fourier_coefficients(self) -> NDArray[np.complexfloating]:
return xp.linalg.inv(self._transfer_function)
H = self._transfer_function
# Tikhonov regularization: solve(H + λI, I) instead of inv(H)
# Scale-aware regularization parameter
lam = 1e-12 * xp.mean(xp.real(xp.conj(H) * H))
identity = xp.eye(H.shape[-1], dtype=H.dtype)
regularized_H = H + lam * identity
return xp.linalg.solve(regularized_H, identity)

@property
def _expectation(self) -> Callable:
Expand Down Expand Up @@ -1555,9 +1561,13 @@ def _estimate_transfer_function(

"""
inverse_fourier_coefficients = ifft(minimum_phase, axis=-3).real
return xp.matmul(
minimum_phase, xp.linalg.inv(inverse_fourier_coefficients[..., 0:1, :, :])
)
H_0 = inverse_fourier_coefficients[..., 0:1, :, :]
# Tikhonov regularization: solve(H_0 + λI, I) instead of inv(H_0)
lam = 1e-12 * xp.mean(H_0 * H_0) # Scale-aware regularization for real matrix
Copy link
Preview

Copilot AI Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regularization parameter uses the same magic number 1e-12 as in the other function. Consider extracting this to a shared constant to ensure consistency and easier maintenance.

Copilot uses AI. Check for mistakes.

identity = xp.eye(H_0.shape[-1], dtype=H_0.dtype)
regularized_H_0 = H_0 + lam * identity
H_0_inv = xp.linalg.solve(regularized_H_0, identity)
return xp.matmul(minimum_phase, H_0_inv)


def _estimate_predictive_power(
Expand Down
53 changes: 53 additions & 0 deletions tests/test_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,56 @@ def test_subset_pairwise_granger_prediction():
for i, j in pairs:
assert np.allclose(gp_subset[..., i, j], gp_all[..., i, j], equal_nan=True)
assert np.allclose(gp_subset[..., j, i], gp_all[..., j, i], equal_nan=True)


def test_mvar_regularized_inverse_near_singular():
"""Test regularized inverse handles near-singular frequency bins."""
np.random.seed(42)
n_time_samples, n_trials, n_tapers, n_fft_samples, n_signals = (
1, 10, 1, 5, 3
)

# Create nearly singular Fourier coefficients by making signals
# highly correlated
fourier_coefficients = np.zeros(
(n_time_samples, n_trials, n_tapers, n_fft_samples, n_signals),
dtype=complex,
)

# Base signal
base_signal = np.random.randn(
n_time_samples, n_trials, n_tapers, n_fft_samples
) + 1j * np.random.randn(
n_time_samples, n_trials, n_tapers, n_fft_samples
)

# Create near-singular cross-spectral matrix by making signals
# nearly dependent
fourier_coefficients[..., 0] = base_signal
fourier_coefficients[..., 1] = base_signal + 1e-10 * (
np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
+ 1j * np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
)
fourier_coefficients[..., 2] = base_signal + 1e-10 * (
np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
+ 1j * np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
)

# This should not raise LinAlgError with regularized inverse
conn = Connectivity(fourier_coefficients=fourier_coefficients)

# Test that MVAR coefficients are computed without error
mvar_coeffs = conn._MVAR_Fourier_coefficients
assert mvar_coeffs is not None
assert np.all(np.isfinite(mvar_coeffs))

# Test that transfer function is computed without error
transfer_func = conn._transfer_function
assert transfer_func is not None
assert np.all(np.isfinite(transfer_func))

# Test connectivity measures that depend on MVAR work
dtf = conn.directed_transfer_function()
assert np.all(np.isfinite(dtf))
assert np.all(dtf >= 0) # DTF should be non-negative
assert np.all(dtf <= 1) # DTF should be bounded by 1
Loading