Skip to content

Commit 53bb957

Browse files
committed
WIP: Add comparison with torch, update tests
1 parent 04a8fb8 commit 53bb957

File tree

1 file changed

+42
-34
lines changed

1 file changed

+42
-34
lines changed

tests/experimental/transforms/test_discrete_fourier.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from autoemulate.experimental.transforms.discrete_fourier import (
33
DiscreteFourierTransform,
44
)
5+
from autoemulate.experimental.types import TensorLike
56

67

78
def create_test_data():
@@ -27,13 +28,11 @@ def test_transform_shapes():
2728

2829
# Test inverse transform shape
2930
x_reconstructed = dft.inv(y)
31+
assert isinstance(x_reconstructed, TensorLike)
3032
assert x_reconstructed.shape == x.shape, (
3133
f"Expected shape {x.shape}, got {x_reconstructed.shape}"
3234
)
3335

34-
print(f"✓ Forward transform: {x.shape}{y.shape}")
35-
print(f"✓ Inverse transform: {y.shape}{x_reconstructed.shape}")
36-
3736

3837
def test_basis_matrix_properties():
3938
"""Test that the basis matrix has correct properties."""
@@ -42,17 +41,14 @@ def test_basis_matrix_properties():
4241
dft = DiscreteFourierTransform(n_components=n_components)
4342
dft.fit(x)
4443

45-
A = dft._basis_matrix
44+
A = dft._basis_matrix.T
4645
expected_shape = (2 * n_components, n_features)
4746

4847
assert A.shape == expected_shape, (
4948
f"Expected basis matrix shape {expected_shape}, got {A.shape}"
5049
)
5150
assert A.dtype == torch.float32, f"Expected float32 dtype, got {A.dtype}"
5251

53-
print(f"✓ Basis matrix shape: {A.shape}")
54-
print(f"✓ Basis matrix dtype: {A.dtype}")
55-
5652

5753
def test_matrix_multiplication_consistency():
5854
"""Test that transforms work correctly via matrix multiplication."""
@@ -61,7 +57,7 @@ def test_matrix_multiplication_consistency():
6157
dft = DiscreteFourierTransform(n_components=n_components)
6258
dft.fit(x)
6359

64-
A = dft._basis_matrix
60+
A = dft._basis_matrix.T
6561

6662
# Test forward transform via matrix multiplication
6763
y_transform = dft(x)
@@ -74,14 +70,11 @@ def test_matrix_multiplication_consistency():
7470
# Test inverse transform via matrix multiplication
7571
x_reconstructed = dft.inv(y_transform)
7672
x_manual = y_transform @ A
77-
73+
assert isinstance(x_reconstructed, TensorLike)
7874
assert torch.allclose(x_reconstructed, x_manual, atol=1e-6), (
7975
"Inverse transform doesn't match manual matrix multiplication"
8076
)
8177

82-
print("✓ Forward transform matches manual computation")
83-
print("✓ Inverse transform matches manual computation")
84-
8578

8679
def test_real_valued_output():
8780
"""Test that all outputs are real-valued (no complex numbers)."""
@@ -100,8 +93,6 @@ def test_real_valued_output():
10093
assert not torch.is_complex(y), "Transform output should not be complex"
10194
assert not torch.is_complex(A), "Basis matrix should not be complex"
10295

103-
print("✓ All outputs are real-valued")
104-
10596

10697
def test_frequency_component_pairing():
10798
"""Test that frequency components are properly paired as real/imaginary columns."""
@@ -123,33 +114,50 @@ def test_frequency_component_pairing():
123114
"Output should have even number of columns for real/imag pairs"
124115
)
125116

126-
print(
127-
f"✓ Output has {n_components} frequency components "
128-
f"as {2 * n_components} real/imag paired columns"
129-
)
130117

118+
def test_against_torch_fft():
119+
"""Test matrix-based DFT against PyTorch's FFT implementation."""
120+
x, n_samples, n_features, n_components = create_test_data()
131121

132-
def run_all_tests():
133-
"""Run all test functions."""
134-
print("Running discrete Fourier transform tests...\n")
122+
# Fit the transform to get selected frequency components
123+
dft = DiscreteFourierTransform(n_components=n_components)
124+
dft.fit(x)
135125

136-
test_transform_shapes()
137-
print()
126+
# Get the selected frequency indices
127+
freq_indices = dft.freq_indices
138128

139-
test_basis_matrix_properties()
140-
print()
129+
# Apply our matrix-based transform
130+
y_matrix = dft(x)
141131

142-
test_matrix_multiplication_consistency()
143-
print()
132+
# Apply PyTorch's FFT to the same data
133+
x_fft = torch.fft.fft(x, dim=1) # FFT along feature dimension
144134

145-
test_real_valued_output()
146-
print()
135+
# Extract the same frequency components that our transform selected
136+
selected_fft = x_fft[:, freq_indices] # Shape: (n_samples, n_components)
147137

148-
test_frequency_component_pairing()
149-
print()
138+
# Convert complex FFT output to real/imag pairs format
139+
# PyTorch FFT gives complex numbers, we need [real, imag, real, imag, ...]
140+
fft_real = selected_fft.real # Shape: (n_samples, n_components)
141+
fft_imag = selected_fft.imag # Shape: (n_samples, n_components)
150142

151-
print("All tests passed! ✓")
143+
# Interleave real and imaginary parts to match our format
144+
y_fft_paired = torch.stack([fft_real, fft_imag], dim=2).reshape(
145+
n_samples, 2 * n_components
146+
)
152147

148+
# Account for normalization difference
149+
# Our DFT uses 1/sqrt(N) normalization, PyTorch's doesn't normalize by default
150+
normalization_factor = 1.0 / torch.sqrt(
151+
torch.tensor(n_features, dtype=torch.float32)
152+
)
153+
y_fft_normalized = y_fft_paired * normalization_factor
154+
155+
# Compare the results
156+
max_error = torch.max(torch.abs(y_matrix - y_fft_normalized))
157+
relative_error = max_error / torch.max(torch.abs(y_fft_normalized))
153158

154-
if __name__ == "__main__":
155-
run_all_tests()
159+
# Should be very close (accounting for floating point precision)
160+
assert max_error < 1e-5, (
161+
f"Matrix DFT differs too much from PyTorch FFT: {max_error}"
162+
)
163+
assert relative_error < 1e-4, f"Relative error too large: {relative_error}"

0 commit comments

Comments
 (0)