Skip to content

Test failures in policies_test.py #103

@samuela

Description

@samuela

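I'm seeing the following test failures on commit 2a6919d (Python 3.12.9, NumPy 2.2.3, built with Nix). Note that the tests pass the expected value as the first argument to `np.testing.assert_array_equal`, so in the output below `ACTUAL` is the value the test expects and `DESIRED` is the value mctx actually returned: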

============================= test session starts ==============================
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0
rootdir: /build/source
collected 20 items

mctx/_src/tests/mctx_test.py .                                           [  5%]
mctx/_src/tests/policies_test.py ....FF....                              [ 55%]
mctx/_src/tests/qtransforms_test.py ..                                   [ 65%]
mctx/_src/tests/seq_halving_test.py .......                              [100%]

=================================== FAILURES ===================================
____________________ PoliciesTest.test_gumbel_muzero_policy ____________________

self = <policies_test.PoliciesTest testMethod=test_gumbel_muzero_policy>

    def test_gumbel_muzero_policy(self):
      root_value = jnp.array([-5.0])
      root = mctx.RootFnOutput(
          prior_logits=jnp.array([
              [0.0, -1.0, 2.0, 3.0],
          ]),
          value=root_value,
          embedding=(),
      )
      rewards = jnp.array([
          [20.0, 3.0, -1.0, 10.0],
      ])
      invalid_actions = jnp.array([
          [1.0, 0.0, 0.0, 1.0],
      ])

      value_scale = 0.05
      maxvisit_init = 60
      num_simulations = 17
      max_depth = 3
      qtransform = functools.partial(
          mctx.qtransform_completed_by_mix_value,
          value_scale=value_scale,
          maxvisit_init=maxvisit_init,
          rescale_values=True)
      policy_output = mctx.gumbel_muzero_policy(
          params=(),
          rng_key=jax.random.PRNGKey(0),
          root=root,
          recurrent_fn=_make_bandit_recurrent_fn(rewards),
          num_simulations=num_simulations,
          invalid_actions=invalid_actions,
          max_depth=max_depth,
          qtransform=qtransform,
          gumbel_scale=1.0)
      # Testing the action.
      expected_action = jnp.array([1], dtype=jnp.int32)
>     np.testing.assert_array_equal(expected_action, policy_output.action)

mctx/_src/tests/policies_test.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (Array([1], dtype=int32), Array([2], dtype=int32)), kwargs = {}
old_name = 'y', new_name = 'desired'

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        for old_name, new_name in zip(old_names, new_names):
            if old_name in kwargs:
                if dep_version:
                    end_version = dep_version.split('.')
                    end_version[1] = str(int(end_version[1]) + 2)
                    end_version = '.'.join(end_version)
                    msg = (f"Use of keyword argument `{old_name}` is "
                           f"deprecated and replaced by `{new_name}`. "
                           f"Support for `{old_name}` will be removed "
                           f"in NumPy {end_version}.")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                if new_name in kwargs:
                    msg = (f"{fun.__name__}() got multiple values for "
                           f"argument now known as `{new_name}`")
                    raise TypeError(msg)
                kwargs[new_name] = kwargs.pop(old_name)
>       return fun(*args, **kwargs)
E       AssertionError:
E       Arrays are not equal
E
E       Mismatched elements: 1 / 1 (100%)
E       Max absolute difference among violations: 1
E       Max relative difference among violations: 0.5
E        ACTUAL: array([1], dtype=int32)
E        DESIRED: array([2], dtype=int32)

/nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
________ PoliciesTest.test_gumbel_muzero_policy_without_invalid_actions ________

self = <policies_test.PoliciesTest testMethod=test_gumbel_muzero_policy_without_invalid_actions>

    def test_gumbel_muzero_policy_without_invalid_actions(self):
      root_value = jnp.array([-5.0])
      root = mctx.RootFnOutput(
          prior_logits=jnp.array([
              [0.0, -1.0, 2.0, 3.0],
          ]),
          value=root_value,
          embedding=(),
      )
      rewards = jnp.array([
          [20.0, 3.0, -1.0, 10.0],
      ])

      value_scale = 0.05
      maxvisit_init = 60
      num_simulations = 17
      max_depth = 3
      qtransform = functools.partial(
          mctx.qtransform_completed_by_mix_value,
          value_scale=value_scale,
          maxvisit_init=maxvisit_init,
          rescale_values=True)
      policy_output = mctx.gumbel_muzero_policy(
          params=(),
          rng_key=jax.random.PRNGKey(0),
          root=root,
          recurrent_fn=_make_bandit_recurrent_fn(rewards),
          num_simulations=num_simulations,
          invalid_actions=None,
          max_depth=max_depth,
          qtransform=qtransform,
          gumbel_scale=1.0)
      # Testing the action.
      expected_action = jnp.array([3], dtype=jnp.int32)
      np.testing.assert_array_equal(expected_action, policy_output.action)

      # Testing the action_weights.
      summary = policy_output.search_tree.summary()
      completed_qvalues = rewards
      max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
      min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
      total_value_scale = (maxvisit_init + summary.visit_counts.max()
                           ) * value_scale
      rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
          max_value - min_value)
      expected_action_weights = jax.nn.softmax(
          root.prior_logits + rescaled_qvalues)
      np.testing.assert_allclose(expected_action_weights,
                                 policy_output.action_weights,
                                 atol=1e-6)

      # Testing the visit_counts.
      expected_visit_counts = jnp.array(
          [[6, 2, 2, 7]])
>     np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)

mctx/_src/tests/policies_test.py:307:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (Array([[6, 2, 2, 7]], dtype=int32), Array([[2., 2., 6., 7.]], dtype=float32))
kwargs = {}, old_name = 'y', new_name = 'desired'

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        for old_name, new_name in zip(old_names, new_names):
            if old_name in kwargs:
                if dep_version:
                    end_version = dep_version.split('.')
                    end_version[1] = str(int(end_version[1]) + 2)
                    end_version = '.'.join(end_version)
                    msg = (f"Use of keyword argument `{old_name}` is "
                           f"deprecated and replaced by `{new_name}`. "
                           f"Support for `{old_name}` will be removed "
                           f"in NumPy {end_version}.")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                if new_name in kwargs:
                    msg = (f"{fun.__name__}() got multiple values for "
                           f"argument now known as `{new_name}`")
                    raise TypeError(msg)
                kwargs[new_name] = kwargs.pop(old_name)
>       return fun(*args, **kwargs)
E       AssertionError:
E       Arrays are not equal
E
E       Mismatched elements: 2 / 4 (50%)
E       Max absolute difference among violations: 4.
E       Max relative difference among violations: 2.
E        ACTUAL: array([[6, 2, 2, 7]], dtype=int32)
E        DESIRED: array([[2., 2., 6., 7.]], dtype=float32)

/nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
=========================== short test summary info ============================
FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy - AssertionError:
FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy_without_invalid_actions - AssertionError:
======================== 2 failed, 18 passed in 14.88s =========================
error: builder for '/nix/store/08wrk930cpjywd59b6x8vyc21p6wc25m-python3.12-mctx-0-unstable-2025-04-04.drv' failed with exit code 1;
       last 25 log lines:
       >                     msg = (f"Use of keyword argument `{old_name}` is "
       >                            f"deprecated and replaced by `{new_name}`. "
       >                            f"Support for `{old_name}` will be removed "
       >                            f"in NumPy {end_version}.")
       >                     warnings.warn(msg, DeprecationWarning, stacklevel=2)
       >                 if new_name in kwargs:
       >                     msg = (f"{fun.__name__}() got multiple values for "
       >                            f"argument now known as `{new_name}`")
       >                     raise TypeError(msg)
       >                 kwargs[new_name] = kwargs.pop(old_name)
       > >       return fun(*args, **kwargs)
       > E       AssertionError:
       > E       Arrays are not equal
       > E
       > E       Mismatched elements: 2 / 4 (50%)
       > E       Max absolute difference among violations: 4.
       > E       Max relative difference among violations: 2.
       > E        ACTUAL: array([[6, 2, 2, 7]], dtype=int32)
       > E        DESIRED: array([[2., 2., 6., 7.]], dtype=float32)
       >
       > /nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
       > =========================== short test summary info ============================
       > FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy - AssertionError:
       > FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy_without_invalid_actions - AssertionError:
       > ======================== 2 failed, 18 passed in 14.88s =========================
       For full logs, run 'nix log /nix/store/08wrk930cpjywd59b6x8vyc21p6wc25m-python3.12-mctx-0-unstable-2025-04-04.drv'.
error: 1 dependencies of derivation '/nix/store/1g944w7bws4kkqc3zcryhj4yk8an05bf-python3-3.12.9-env.drv' failed to build

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions