22from autoemulate .experimental .transforms .discrete_fourier import (
33 DiscreteFourierTransform ,
44)
5+ from autoemulate .experimental .types import TensorLike
56
67
78def create_test_data ():
@@ -27,13 +28,11 @@ def test_transform_shapes():
2728
2829 # Test inverse transform shape
2930 x_reconstructed = dft .inv (y )
31+ assert isinstance (x_reconstructed , TensorLike )
3032 assert x_reconstructed .shape == x .shape , (
3133 f"Expected shape { x .shape } , got { x_reconstructed .shape } "
3234 )
3335
34- print (f"✓ Forward transform: { x .shape } → { y .shape } " )
35- print (f"✓ Inverse transform: { y .shape } → { x_reconstructed .shape } " )
36-
3736
3837def test_basis_matrix_properties ():
3938 """Test that the basis matrix has correct properties."""
@@ -42,17 +41,14 @@ def test_basis_matrix_properties():
4241 dft = DiscreteFourierTransform (n_components = n_components )
4342 dft .fit (x )
4443
45- A = dft ._basis_matrix
44+ A = dft ._basis_matrix . T
4645 expected_shape = (2 * n_components , n_features )
4746
4847 assert A .shape == expected_shape , (
4948 f"Expected basis matrix shape { expected_shape } , got { A .shape } "
5049 )
5150 assert A .dtype == torch .float32 , f"Expected float32 dtype, got { A .dtype } "
5251
53- print (f"✓ Basis matrix shape: { A .shape } " )
54- print (f"✓ Basis matrix dtype: { A .dtype } " )
55-
5652
5753def test_matrix_multiplication_consistency ():
5854 """Test that transforms work correctly via matrix multiplication."""
@@ -61,7 +57,7 @@ def test_matrix_multiplication_consistency():
6157 dft = DiscreteFourierTransform (n_components = n_components )
6258 dft .fit (x )
6359
64- A = dft ._basis_matrix
60+ A = dft ._basis_matrix . T
6561
6662 # Test forward transform via matrix multiplication
6763 y_transform = dft (x )
@@ -74,14 +70,11 @@ def test_matrix_multiplication_consistency():
7470 # Test inverse transform via matrix multiplication
7571 x_reconstructed = dft .inv (y_transform )
7672 x_manual = y_transform @ A
77-
73+ assert isinstance ( x_reconstructed , TensorLike )
7874 assert torch .allclose (x_reconstructed , x_manual , atol = 1e-6 ), (
7975 "Inverse transform doesn't match manual matrix multiplication"
8076 )
8177
82- print ("✓ Forward transform matches manual computation" )
83- print ("✓ Inverse transform matches manual computation" )
84-
8578
8679def test_real_valued_output ():
8780 """Test that all outputs are real-valued (no complex numbers)."""
@@ -100,8 +93,6 @@ def test_real_valued_output():
10093 assert not torch .is_complex (y ), "Transform output should not be complex"
10194 assert not torch .is_complex (A ), "Basis matrix should not be complex"
10295
103- print ("✓ All outputs are real-valued" )
104-
10596
10697def test_frequency_component_pairing ():
10798 """Test that frequency components are properly paired as real/imaginary columns."""
@@ -123,33 +114,50 @@ def test_frequency_component_pairing():
123114 "Output should have even number of columns for real/imag pairs"
124115 )
125116
126- print (
127- f"✓ Output has { n_components } frequency components "
128- f"as { 2 * n_components } real/imag paired columns"
129- )
130117
118+ def test_against_torch_fft ():
119+ """Test matrix-based DFT against PyTorch's FFT implementation."""
120+ x , n_samples , n_features , n_components = create_test_data ()
131121
132- def run_all_tests ():
133- """Run all test functions."""
134- print ( "Running discrete Fourier transform tests... \n " )
122+ # Fit the transform to get selected frequency components
123+ dft = DiscreteFourierTransform ( n_components = n_components )
124+ dft . fit ( x )
135125
136- test_transform_shapes ()
137- print ()
126+ # Get the selected frequency indices
127+ freq_indices = dft . freq_indices
138128
139- test_basis_matrix_properties ()
140- print ( )
129+ # Apply our matrix-based transform
130+ y_matrix = dft ( x )
141131
142- test_matrix_multiplication_consistency ()
143- print ()
132+ # Apply PyTorch's FFT to the same data
133+ x_fft = torch . fft . fft ( x , dim = 1 ) # FFT along feature dimension
144134
145- test_real_valued_output ()
146- print ( )
135+ # Extract the same frequency components that our transform selected
136+ selected_fft = x_fft [:, freq_indices ] # Shape: (n_samples, n_components )
147137
148- test_frequency_component_pairing ()
149- print ()
138+ # Convert complex FFT output to real/imag pairs format
139+ # PyTorch FFT gives complex numbers, we need [real, imag, real, imag, ...]
140+ fft_real = selected_fft .real # Shape: (n_samples, n_components)
141+ fft_imag = selected_fft .imag # Shape: (n_samples, n_components)
150142
151- print ("All tests passed! ✓" )
143+ # Interleave real and imaginary parts to match our format
144+ y_fft_paired = torch .stack ([fft_real , fft_imag ], dim = 2 ).reshape (
145+ n_samples , 2 * n_components
146+ )
152147
148+ # Account for normalization difference
149+ # Our DFT uses 1/sqrt(N) normalization, PyTorch's doesn't normalize by default
150+ normalization_factor = 1.0 / torch .sqrt (
151+ torch .tensor (n_features , dtype = torch .float32 )
152+ )
153+ y_fft_normalized = y_fft_paired * normalization_factor
154+
155+ # Compare the results
156+ max_error = torch .max (torch .abs (y_matrix - y_fft_normalized ))
157+ relative_error = max_error / torch .max (torch .abs (y_fft_normalized ))
153158
154- if __name__ == "__main__" :
155- run_all_tests ()
159+ # Should be very close (accounting for floating point precision)
160+ assert max_error < 1e-5 , (
161+ f"Matrix DFT differs too much from PyTorch FFT: { max_error } "
162+ )
163+ assert relative_error < 1e-4 , f"Relative error too large: { relative_error } "
0 commit comments