2
2
from autoemulate .experimental .transforms .discrete_fourier import (
3
3
DiscreteFourierTransform ,
4
4
)
5
+ from autoemulate .experimental .types import TensorLike
5
6
6
7
7
8
def create_test_data ():
@@ -27,13 +28,11 @@ def test_transform_shapes():
27
28
28
29
# Test inverse transform shape
29
30
x_reconstructed = dft .inv (y )
31
+ assert isinstance (x_reconstructed , TensorLike )
30
32
assert x_reconstructed .shape == x .shape , (
31
33
f"Expected shape { x .shape } , got { x_reconstructed .shape } "
32
34
)
33
35
34
- print (f"✓ Forward transform: { x .shape } → { y .shape } " )
35
- print (f"✓ Inverse transform: { y .shape } → { x_reconstructed .shape } " )
36
-
37
36
38
37
def test_basis_matrix_properties ():
39
38
"""Test that the basis matrix has correct properties."""
@@ -42,17 +41,14 @@ def test_basis_matrix_properties():
42
41
dft = DiscreteFourierTransform (n_components = n_components )
43
42
dft .fit (x )
44
43
45
- A = dft ._basis_matrix
44
+ A = dft ._basis_matrix . T
46
45
expected_shape = (2 * n_components , n_features )
47
46
48
47
assert A .shape == expected_shape , (
49
48
f"Expected basis matrix shape { expected_shape } , got { A .shape } "
50
49
)
51
50
assert A .dtype == torch .float32 , f"Expected float32 dtype, got { A .dtype } "
52
51
53
- print (f"✓ Basis matrix shape: { A .shape } " )
54
- print (f"✓ Basis matrix dtype: { A .dtype } " )
55
-
56
52
57
53
def test_matrix_multiplication_consistency ():
58
54
"""Test that transforms work correctly via matrix multiplication."""
@@ -61,7 +57,7 @@ def test_matrix_multiplication_consistency():
61
57
dft = DiscreteFourierTransform (n_components = n_components )
62
58
dft .fit (x )
63
59
64
- A = dft ._basis_matrix
60
+ A = dft ._basis_matrix . T
65
61
66
62
# Test forward transform via matrix multiplication
67
63
y_transform = dft (x )
@@ -74,14 +70,11 @@ def test_matrix_multiplication_consistency():
74
70
# Test inverse transform via matrix multiplication
75
71
x_reconstructed = dft .inv (y_transform )
76
72
x_manual = y_transform @ A
77
-
73
+ assert isinstance ( x_reconstructed , TensorLike )
78
74
assert torch .allclose (x_reconstructed , x_manual , atol = 1e-6 ), (
79
75
"Inverse transform doesn't match manual matrix multiplication"
80
76
)
81
77
82
- print ("✓ Forward transform matches manual computation" )
83
- print ("✓ Inverse transform matches manual computation" )
84
-
85
78
86
79
def test_real_valued_output ():
87
80
"""Test that all outputs are real-valued (no complex numbers)."""
@@ -100,8 +93,6 @@ def test_real_valued_output():
100
93
assert not torch .is_complex (y ), "Transform output should not be complex"
101
94
assert not torch .is_complex (A ), "Basis matrix should not be complex"
102
95
103
- print ("✓ All outputs are real-valued" )
104
-
105
96
106
97
def test_frequency_component_pairing ():
107
98
"""Test that frequency components are properly paired as real/imaginary columns."""
@@ -123,33 +114,50 @@ def test_frequency_component_pairing():
123
114
"Output should have even number of columns for real/imag pairs"
124
115
)
125
116
126
- print (
127
- f"✓ Output has { n_components } frequency components "
128
- f"as { 2 * n_components } real/imag paired columns"
129
- )
130
117
118
+ def test_against_torch_fft ():
119
+ """Test matrix-based DFT against PyTorch's FFT implementation."""
120
+ x , n_samples , n_features , n_components = create_test_data ()
131
121
132
- def run_all_tests ():
133
- """Run all test functions."""
134
- print ( "Running discrete Fourier transform tests... \n " )
122
+ # Fit the transform to get selected frequency components
123
+ dft = DiscreteFourierTransform ( n_components = n_components )
124
+ dft . fit ( x )
135
125
136
- test_transform_shapes ()
137
- print ()
126
+ # Get the selected frequency indices
127
+ freq_indices = dft . freq_indices
138
128
139
- test_basis_matrix_properties ()
140
- print ( )
129
+ # Apply our matrix-based transform
130
+ y_matrix = dft ( x )
141
131
142
- test_matrix_multiplication_consistency ()
143
- print ()
132
+ # Apply PyTorch's FFT to the same data
133
+ x_fft = torch . fft . fft ( x , dim = 1 ) # FFT along feature dimension
144
134
145
- test_real_valued_output ()
146
- print ( )
135
+ # Extract the same frequency components that our transform selected
136
+ selected_fft = x_fft [:, freq_indices ] # Shape: (n_samples, n_components )
147
137
148
- test_frequency_component_pairing ()
149
- print ()
138
+ # Convert complex FFT output to real/imag pairs format
139
+ # PyTorch FFT gives complex numbers, we need [real, imag, real, imag, ...]
140
+ fft_real = selected_fft .real # Shape: (n_samples, n_components)
141
+ fft_imag = selected_fft .imag # Shape: (n_samples, n_components)
150
142
151
- print ("All tests passed! ✓" )
143
+ # Interleave real and imaginary parts to match our format
144
+ y_fft_paired = torch .stack ([fft_real , fft_imag ], dim = 2 ).reshape (
145
+ n_samples , 2 * n_components
146
+ )
152
147
148
+ # Account for normalization difference
149
+ # Our DFT uses 1/sqrt(N) normalization, PyTorch's doesn't normalize by default
150
+ normalization_factor = 1.0 / torch .sqrt (
151
+ torch .tensor (n_features , dtype = torch .float32 )
152
+ )
153
+ y_fft_normalized = y_fft_paired * normalization_factor
154
+
155
+ # Compare the results
156
+ max_error = torch .max (torch .abs (y_matrix - y_fft_normalized ))
157
+ relative_error = max_error / torch .max (torch .abs (y_fft_normalized ))
153
158
154
- if __name__ == "__main__" :
155
- run_all_tests ()
159
+ # Should be very close (accounting for floating point precision)
160
+ assert max_error < 1e-5 , (
161
+ f"Matrix DFT differs too much from PyTorch FFT: { max_error } "
162
+ )
163
+ assert relative_error < 1e-4 , f"Relative error too large: { relative_error } "
0 commit comments