Issues with Stochastic MuZero #60

Closed
carlosgmartin opened this issue Aug 8, 2023 · 4 comments

Comments

@carlosgmartin
Contributor

I'm having issues with mctx.stochastic_muzero_policy. Here's an example:

import jax
import mctx
from jax import numpy as jnp

num_actions = 4
num_chance_outcomes = 2


def decision_recurrent_fn(params, key, action, state):
    return (
        mctx.DecisionRecurrentFnOutput(
            chance_logits=jnp.full(num_chance_outcomes, 0.0),
            afterstate_value=jnp.array(0.0),
        ),
        state,
    )


def chance_recurrent_fn(params, key, action, afterstate):
    return (
        mctx.ChanceRecurrentFnOutput(
            action_logits=jnp.full(num_actions, 0.0),
            value=jnp.array(0.0),
            # reward=jnp.array(1.),
            reward=1 + (action == 0) * 100,
            discount=jnp.array(0.0),
        ),
        afterstate,
    )


def root_fn(state):
    return mctx.RootFnOutput(
        prior_logits=jnp.full(num_actions, 0.0),
        value=jnp.array(0.0),
        embedding=state,
    )


def main():
    root = root_fn(jnp.full(4, 0.0))
    # Add a leading batch dimension; jax.tree_map was removed in newer JAX releases.
    root = jax.tree_util.tree_map(lambda x: x[None], root)

    key = jax.random.PRNGKey(0)

    output = mctx.stochastic_muzero_policy(
        params=jnp.full(20, 0.0),
        rng_key=key,
        root=root,
        decision_recurrent_fn=jax.vmap(decision_recurrent_fn, [None, None, 0, 0]),
        chance_recurrent_fn=jax.vmap(chance_recurrent_fn, [None, None, 0, 0]),
        num_simulations=1000,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
    )
    assert (output.search_tree.children_rewards == 0).all()
    print(output.action_weights)  # [[0.007 0.451 0.063 0.479]]


if __name__ == "__main__":
    main()

The first issue is that the children_rewards are all 0, despite the fact that chance_recurrent_fn always yields a positive reward.

The second issue is that the final weight of the zeroth action (which receives an additional reward of 100) is not higher than the rest, despite a large number of simulations.

Any idea what might be causing these issues?

@fidlej
Collaborator

fidlej commented Aug 9, 2023

Thanks for sharing the minimal example. I can clear up one confusion: the action argument passed to chance_recurrent_fn(params, key, action, afterstate) is actually the chance outcome, not the action chosen at the decision node. To give different actions different rewards, modify decision_recurrent_fn to output a different afterstate for each action.

You can take inspiration from the bandit in the tests:
https://github.com/deepmind/mctx/blob/bfb7316b96f9e5b04744e8872c1abba9b2dac6b9/mctx/_src/tests/policies_test.py#L42

I will improve the documentation for the chance_recurrent_fn. Sorry for the confusion.
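
A minimal sketch of that suggestion, assuming a one-hot afterstate is an acceptable encoding of the chosen action. The names decision_step, chance_step and reward_for are hypothetical, the params and key arguments are dropped for brevity, and plain dicts stand in for mctx.DecisionRecurrentFnOutput and mctx.ChanceRecurrentFnOutput so that the snippet needs only JAX. It is meant to illustrate the data flow, not to reproduce the bandit from the tests.

```python
import jax
from jax import numpy as jnp

num_actions = 4
num_chance_outcomes = 2


def decision_step(action, state):
    """Decision node: fold the chosen action into the afterstate."""
    del state  # this toy afterstate depends only on the action
    # Assumption: the afterstate keeps the same shape as the state.
    afterstate = jax.nn.one_hot(action, num_actions)
    outputs = dict(
        chance_logits=jnp.zeros(num_chance_outcomes),
        afterstate_value=jnp.array(0.0),
    )
    return outputs, afterstate


def chance_step(chance_outcome, afterstate):
    """Chance node: the first argument is a chance outcome, not an action."""
    del chance_outcome  # the reward here ignores the sampled outcome
    took_action_zero = afterstate[0]  # 1.0 only if action 0 was chosen
    outputs = dict(
        action_logits=jnp.zeros(num_actions),
        value=jnp.array(0.0),
        reward=1.0 + 100.0 * took_action_zero,
        discount=jnp.array(0.0),
    )
    return outputs, afterstate


def reward_for(action):
    """Reward after taking `action` and then sampling chance outcome 0."""
    state = jnp.zeros(num_actions)
    _, afterstate = decision_step(action, state)
    outputs, _ = chance_step(jnp.array(0), afterstate)
    return outputs["reward"]


rewards = jax.vmap(reward_for)(jnp.arange(num_actions))
print(rewards)
```

Because the action is recorded in the afterstate, the chance step can pay the extra reward for action 0 without ever seeing the action directly.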

@carlosgmartin
Contributor Author

@fidlej Thanks for your reply. Perhaps the argument can be renamed to outcome, for clarity?

copybara-service bot pushed a commit that referenced this issue Aug 9, 2023
Fixes #60.

PiperOrigin-RevId: 555288953
@carlosgmartin
Contributor Author

@fidlej Any idea about the children_rewards issue?

@fidlej
Collaborator

fidlej commented Aug 13, 2023

The output.search_tree contains only the actions relevant for the decision nodes.
The masking is done here:
https://github.com/deepmind/mctx/blob/bfb7316b96f9e5b04744e8872c1abba9b2dac6b9/mctx/_src/policies.py#L366

The zeros in children_rewards then make sense: the reward is zero for the children of the decision nodes, because decision_recurrent_fn does not output a reward.

copybara-service bot pushed a commit that referenced this issue Aug 15, 2023
Fixes #60.

PiperOrigin-RevId: 555288953