-
Notifications
You must be signed in to change notification settings - Fork 208
Open
Description
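I'm seeing the following two test failures in mctx/_src/tests/policies_test.py at commit 2a6919d, when building with Nix (Python 3.12.9, pytest 8.3.5, NumPy 2.2.3). Note that the tests call np.testing.assert_array_equal(expected, actual), so the ACTUAL/DESIRED labels in NumPy's output below are swapped: in test_gumbel_muzero_policy the expected action is [1] but the policy returned [2], and in test_gumbel_muzero_policy_without_invalid_actions the expected visit counts are [[6, 2, 2, 7]] but the search produced [[2, 2, 6, 7]]. Full log: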
============================= test session starts ==============================
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0
rootdir: /build/source
collected 20 items
mctx/_src/tests/mctx_test.py . [ 5%]
mctx/_src/tests/policies_test.py ....FF.... [ 55%]
mctx/_src/tests/qtransforms_test.py .. [ 65%]
mctx/_src/tests/seq_halving_test.py ....... [100%]
=================================== FAILURES ===================================
____________________ PoliciesTest.test_gumbel_muzero_policy ____________________
self = <policies_test.PoliciesTest testMethod=test_gumbel_muzero_policy>
def test_gumbel_muzero_policy(self):
root_value = jnp.array([-5.0])
root = mctx.RootFnOutput(
prior_logits=jnp.array([
[0.0, -1.0, 2.0, 3.0],
]),
value=root_value,
embedding=(),
)
rewards = jnp.array([
[20.0, 3.0, -1.0, 10.0],
])
invalid_actions = jnp.array([
[1.0, 0.0, 0.0, 1.0],
])
value_scale = 0.05
maxvisit_init = 60
num_simulations = 17
max_depth = 3
qtransform = functools.partial(
mctx.qtransform_completed_by_mix_value,
value_scale=value_scale,
maxvisit_init=maxvisit_init,
rescale_values=True)
policy_output = mctx.gumbel_muzero_policy(
params=(),
rng_key=jax.random.PRNGKey(0),
root=root,
recurrent_fn=_make_bandit_recurrent_fn(rewards),
num_simulations=num_simulations,
invalid_actions=invalid_actions,
max_depth=max_depth,
qtransform=qtransform,
gumbel_scale=1.0)
# Testing the action.
expected_action = jnp.array([1], dtype=jnp.int32)
> np.testing.assert_array_equal(expected_action, policy_output.action)
mctx/_src/tests/policies_test.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Array([1], dtype=int32), Array([2], dtype=int32)), kwargs = {}
old_name = 'y', new_name = 'desired'
@functools.wraps(fun)
def wrapper(*args, **kwargs):
for old_name, new_name in zip(old_names, new_names):
if old_name in kwargs:
if dep_version:
end_version = dep_version.split('.')
end_version[1] = str(int(end_version[1]) + 2)
end_version = '.'.join(end_version)
msg = (f"Use of keyword argument `{old_name}` is "
f"deprecated and replaced by `{new_name}`. "
f"Support for `{old_name}` will be removed "
f"in NumPy {end_version}.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if new_name in kwargs:
msg = (f"{fun.__name__}() got multiple values for "
f"argument now known as `{new_name}`")
raise TypeError(msg)
kwargs[new_name] = kwargs.pop(old_name)
> return fun(*args, **kwargs)
E AssertionError:
E Arrays are not equal
E
E Mismatched elements: 1 / 1 (100%)
E Max absolute difference among violations: 1
E Max relative difference among violations: 0.5
E ACTUAL: array([1], dtype=int32)
E DESIRED: array([2], dtype=int32)
/nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
________ PoliciesTest.test_gumbel_muzero_policy_without_invalid_actions ________
self = <policies_test.PoliciesTest testMethod=test_gumbel_muzero_policy_without_invalid_actions>
def test_gumbel_muzero_policy_without_invalid_actions(self):
root_value = jnp.array([-5.0])
root = mctx.RootFnOutput(
prior_logits=jnp.array([
[0.0, -1.0, 2.0, 3.0],
]),
value=root_value,
embedding=(),
)
rewards = jnp.array([
[20.0, 3.0, -1.0, 10.0],
])
value_scale = 0.05
maxvisit_init = 60
num_simulations = 17
max_depth = 3
qtransform = functools.partial(
mctx.qtransform_completed_by_mix_value,
value_scale=value_scale,
maxvisit_init=maxvisit_init,
rescale_values=True)
policy_output = mctx.gumbel_muzero_policy(
params=(),
rng_key=jax.random.PRNGKey(0),
root=root,
recurrent_fn=_make_bandit_recurrent_fn(rewards),
num_simulations=num_simulations,
invalid_actions=None,
max_depth=max_depth,
qtransform=qtransform,
gumbel_scale=1.0)
# Testing the action.
expected_action = jnp.array([3], dtype=jnp.int32)
np.testing.assert_array_equal(expected_action, policy_output.action)
# Testing the action_weights.
summary = policy_output.search_tree.summary()
completed_qvalues = rewards
max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
total_value_scale = (maxvisit_init + summary.visit_counts.max()
) * value_scale
rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
max_value - min_value)
expected_action_weights = jax.nn.softmax(
root.prior_logits + rescaled_qvalues)
np.testing.assert_allclose(expected_action_weights,
policy_output.action_weights,
atol=1e-6)
# Testing the visit_counts.
expected_visit_counts = jnp.array(
[[6, 2, 2, 7]])
> np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)
mctx/_src/tests/policies_test.py:307:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Array([[6, 2, 2, 7]], dtype=int32), Array([[2., 2., 6., 7.]], dtype=float32))
kwargs = {}, old_name = 'y', new_name = 'desired'
@functools.wraps(fun)
def wrapper(*args, **kwargs):
for old_name, new_name in zip(old_names, new_names):
if old_name in kwargs:
if dep_version:
end_version = dep_version.split('.')
end_version[1] = str(int(end_version[1]) + 2)
end_version = '.'.join(end_version)
msg = (f"Use of keyword argument `{old_name}` is "
f"deprecated and replaced by `{new_name}`. "
f"Support for `{old_name}` will be removed "
f"in NumPy {end_version}.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if new_name in kwargs:
msg = (f"{fun.__name__}() got multiple values for "
f"argument now known as `{new_name}`")
raise TypeError(msg)
kwargs[new_name] = kwargs.pop(old_name)
> return fun(*args, **kwargs)
E AssertionError:
E Arrays are not equal
E
E Mismatched elements: 2 / 4 (50%)
E Max absolute difference among violations: 4.
E Max relative difference among violations: 2.
E ACTUAL: array([[6, 2, 2, 7]], dtype=int32)
E DESIRED: array([[2., 2., 6., 7.]], dtype=float32)
/nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
=========================== short test summary info ============================
FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy - AssertionError:
FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy_without_invalid_actions - AssertionError:
======================== 2 failed, 18 passed in 14.88s =========================
error: builder for '/nix/store/08wrk930cpjywd59b6x8vyc21p6wc25m-python3.12-mctx-0-unstable-2025-04-04.drv' failed with exit code 1;
last 25 log lines:
> msg = (f"Use of keyword argument `{old_name}` is "
> f"deprecated and replaced by `{new_name}`. "
> f"Support for `{old_name}` will be removed "
> f"in NumPy {end_version}.")
> warnings.warn(msg, DeprecationWarning, stacklevel=2)
> if new_name in kwargs:
> msg = (f"{fun.__name__}() got multiple values for "
> f"argument now known as `{new_name}`")
> raise TypeError(msg)
> kwargs[new_name] = kwargs.pop(old_name)
> > return fun(*args, **kwargs)
> E AssertionError:
> E Arrays are not equal
> E
> E Mismatched elements: 2 / 4 (50%)
> E Max absolute difference among violations: 4.
> E Max relative difference among violations: 2.
> E ACTUAL: array([[6, 2, 2, 7]], dtype=int32)
> E DESIRED: array([[2., 2., 6., 7.]], dtype=float32)
>
> /nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
> =========================== short test summary info ============================
> FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy - AssertionError:
> FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy_without_invalid_actions - AssertionError:
> ======================== 2 failed, 18 passed in 14.88s =========================
For full logs, run 'nix log /nix/store/08wrk930cpjywd59b6x8vyc21p6wc25m-python3.12-mctx-0-unstable-2025-04-04.drv'.
error: 1 dependencies of derivation '/nix/store/1g944w7bws4kkqc3zcryhj4yk8an05bf-python3-3.12.9-env.drv' failed to build
Yyassin
Metadata
Assignees
Labels
No labels