Several tests of keras/src/ops/numpy_test.py::NumpyDtypeTest fail with AssertionError #21629

@GaetanLepage

When using jax==0.7.1, several tests in keras/src/ops/numpy_test.py::NumpyDtypeTest fail with:

FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_kaiser_uint8 - AssertionError:
- float32
+ float64

For example:

______________________ NumpyDtypeTest.test_hanning_int64 _______________________

self = <keras.src.ops.numpy_test.NumpyDtypeTest testMethod=test_hanning_int64>
dtype = 'int64'

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_hanning(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.hanning(x_jax).dtype)

>       self.assertEqual(
            standardize_dtype(knp.hanning(x).dtype), expected_dtype
        )
E       AssertionError:
E       - float32
E       + float64

keras/src/ops/numpy_test.py:5719: AssertionError

Context: bump jax from 0.6.0 to 0.7.1 in nixpkgs.

Full logs: https://paste.glepage.com/raw/snail-otter-frog
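
The comparison made by the failing test can be sketched outside the test suite. In the assertion diff above, the `-` line appears to correspond to the first argument (the Keras result, float32) and the `+` line to the second (the JAX result, float64). The snippet below is only a rough reproduction aid, not part of the Keras test suite: `dtype_mismatches` and `main` are hypothetical names, the import paths `keras.backend.standardize_dtype` and `keras.src.ops.numpy` are assumptions about the installed Keras version, and only the two dtypes named in this report are checked.

```python
def dtype_mismatches(candidate, reference, dtypes):
    """Return (dtype, candidate_result, reference_result) for each disagreement.

    `candidate` and `reference` are callables that map an input dtype name
    to the dtype name of the op's result (hypothetical helper).
    """
    mismatches = []
    for dtype in dtypes:
        got, expected = candidate(dtype), reference(dtype)
        if got != expected:
            mismatches.append((dtype, got, expected))
    return mismatches


def main():
    # Imported lazily so the helper above works without JAX or Keras installed.
    # Both import paths below are assumptions about the installed Keras version.
    import jax.numpy as jnp
    from keras.backend import standardize_dtype
    from keras.src.ops import numpy as knp

    def keras_result(dtype):
        # Same call pattern as test_hanning: scalar input, dtype of the result.
        return standardize_dtype(knp.hanning(knp.ones((), dtype=dtype)).dtype)

    def jax_result(dtype):
        return standardize_dtype(jnp.hanning(jnp.ones((), dtype=dtype)).dtype)

    # Only the dtypes named in this report; the test itself uses ALL_DTYPES.
    for dtype, got, expected in dtype_mismatches(
        keras_result, jax_result, ["uint8", "int64"]
    ):
        print(f"hanning[{dtype}]: keras={got} jax={expected}")
```

Calling `main()` with the JAX backend selected should print one line per disagreeing dtype. The result may depend on whether 64-bit mode is enabled in JAX, so the output is not guaranteed to match the test run above.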
