When using `jax==0.7.1`, several tests in `keras/src/ops/numpy_test.py::NumpyDtypeTest` fail with:

```
FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_kaiser_uint8 - AssertionError: - float32 + float64
```

For example:

```
______________________ NumpyDtypeTest.test_hanning_int64 _______________________

self = <keras.src.ops.numpy_test.NumpyDtypeTest testMethod=test_hanning_int64>
dtype = 'int64'

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_hanning(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.hanning(x_jax).dtype)

>       self.assertEqual(
            standardize_dtype(knp.hanning(x).dtype), expected_dtype
        )
E       AssertionError:
E       - float32
E       + float64

keras/src/ops/numpy_test.py:5719: AssertionError
```

In other words, `knp.hanning` returns `float32`, while the JAX reference (`jnp.hanning`) returns `float64`.

Context: [bump jax from 0.6.0 to 0.7.1 in nixpkgs](https://github.com/NixOS/nixpkgs/pull/427588).

Full logs: https://paste.glepage.com/raw/snail-otter-frog