Skip to content

Commit 04a8fb8

Browse files
committed
Add discrete fourier transform to tests
1 parent 6f0c87a commit 04a8fb8

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tests/experimental/transforms/test_transforms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from autoemulate.experimental.emulators import GaussianProcessExact
55
from autoemulate.experimental.transforms import (
6+
DiscreteFourierTransform,
67
PCATransform,
78
StandardizeTransform,
89
VAETransform,
@@ -18,6 +19,7 @@
1819
(PCATransform(n_components=2), (20, 2)),
1920
(VAETransform(latent_dim=2), (20, 2)),
2021
(StandardizeTransform(), (20, 5)),
22+
(DiscreteFourierTransform(n_components=2), (20, 2)),
2123
],
2224
)
2325
def test_transform_shapes(sample_data_y2d, transform, expected_shape):
@@ -30,7 +32,12 @@ def test_transform_shapes(sample_data_y2d, transform, expected_shape):
3032

3133
@pytest.mark.parametrize(
3234
("transform"),
33-
[PCATransform(n_components=2), VAETransform(latent_dim=2), StandardizeTransform()],
35+
[
36+
DiscreteFourierTransform(n_components=2),
37+
PCATransform(n_components=2),
38+
VAETransform(latent_dim=2),
39+
StandardizeTransform(),
40+
],
3441
)
3542
def test_transform_inverse_for_gaussians(sample_data_y2d, transform):
3643
x, y = sample_data_y2d

0 commit comments

Comments
 (0)