
Commit

Fix a shape error in stochastic_muzero_policy (thanks Evan Walters).
Fixes #37

PiperOrigin-RevId: 505342675
fidlej authored and MctxDev committed Jan 28, 2023
1 parent 6b827d2 commit 577fc77
Showing 2 changed files with 15 additions and 11 deletions.
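
The underlying error: on a JAX array, len(x) returns the leading (batch) dimension, not the rank, so the broadcast helper fixed below could build the wrong number of singleton axes whenever the batch size differed from the rank of the leaf arrays. A standalone sketch of the corrected behaviour (not the library source; shapes are made up for illustration):

import jax.numpy as jnp

def broadcast_where(is_decision_node, decision_leaf, chance_leaf):
  # One singleton axis per non-batch dimension of the leaf arrays, so the
  # [B]-shaped condition broadcasts against [B, ...]-shaped leaves.
  extra_dims = [1] * (len(decision_leaf.shape) - 1)   # rank - 1, not batch - 1
  expanded = jnp.reshape(is_decision_node, [-1] + extra_dims)
  return jnp.where(expanded, decision_leaf, chance_leaf)

is_decision = jnp.array([True, False])   # batch of 2
decision = jnp.ones([2, 3, 4])           # [B, ...] leaf
chance = jnp.zeros([2, 3, 4])
out = broadcast_where(is_decision, decision, chance)  # shape (2, 3, 4)
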
6 changes: 3 additions & 3 deletions mctx/_src/policies.py
@@ -265,11 +265,11 @@ def stochastic_muzero_policy(
decision_recurrent_fn: a callable to be called on the leaf decision nodes
and unvisited actions retrieved by the simulation step, which takes as
args `(params, rng_key, action, state_embedding)` and returns a
- `DecisionRecurrentFnOutput`.
+ `(DecisionRecurrentFnOutput, afterstate_embedding)`.
chance_recurrent_fn: a callable to be called on the leaf chance nodes and
unvisited actions retrieved by the simulation step, which takes as args
`(params, rng_key, action, afterstate_embedding)` and returns a
- `ChanceRecurrentFnOutput`.
+ `(ChanceRecurrentFnOutput, state_embedding)`.
num_simulations: the number of simulations.
num_actions: number of environment actions.
num_chance_outcomes: number of chance outcomes following an afterstate.
@@ -471,7 +471,7 @@ def stochastic_recurrent_fn(
is_decision_node=jnp.logical_not(state.is_decision_node))

def _broadcast_where(decision_leaf, chance_leaf):
- extra_dims = [1] * (len(decision_leaf) - 1)
+ extra_dims = [1] * (len(decision_leaf.shape) - 1)
expanded_is_decision = jnp.reshape(state.is_decision_node,
[-1] + extra_dims)
return jnp.where(
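
The docstring update above makes the return contract explicit: both callables return a pair of (output, embedding to carry forward). A minimal hypothetical pair satisfying that contract (function names, constants and the tuple-shaped afterstate embedding are assumptions for illustration, not library code):

import jax.numpy as jnp
import mctx

NUM_ACTIONS = 4          # assumed sizes for the sketch
NUM_CHANCE_OUTCOMES = 3

def my_decision_recurrent_fn(params, rng_key, action, state_embedding):
  del params, rng_key
  batch_size = action.shape[0]
  output = mctx.DecisionRecurrentFnOutput(
      chance_logits=jnp.zeros([batch_size, NUM_CHANCE_OUTCOMES]),
      afterstate_value=jnp.zeros([batch_size]))
  # Carry forward whatever the chance step needs, e.g. the chosen action
  # together with the current state embedding.
  afterstate_embedding = (action, state_embedding)
  return output, afterstate_embedding

def my_chance_recurrent_fn(params, rng_key, chance_outcome, afterstate_embedding):
  del params, rng_key, chance_outcome
  action, state_embedding = afterstate_embedding
  batch_size = action.shape[0]
  output = mctx.ChanceRecurrentFnOutput(
      action_logits=jnp.zeros([batch_size, NUM_ACTIONS]),
      value=jnp.zeros([batch_size]),
      discount=jnp.full([batch_size], 0.99),
      reward=jnp.zeros([batch_size]))
  return output, state_embedding
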
20 changes: 12 additions & 8 deletions mctx/_src/tests/policies_test.py
@@ -42,26 +42,28 @@ def recurrent_fn(params, rng_key, action, embedding):
def _make_bandit_decision_and_chance_fns(rewards, num_chance_outcomes):

def decision_recurrent_fn(params, rng_key, action, embedding):
- del params, rng_key, embedding
+ del params, rng_key
batch_size = action.shape[0]
reward = rewards[jnp.arange(batch_size), action]
dummy_chance_logits = jnp.full([batch_size, num_chance_outcomes],
-jnp.inf).at[:, 0].set(1.0)
+ afterstate_embedding = (action, embedding)
return mctx.DecisionRecurrentFnOutput(
chance_logits=dummy_chance_logits,
- afterstate_value=jnp.zeros_like(reward)), (action)
+ afterstate_value=jnp.zeros_like(reward)), afterstate_embedding

- def chance_recurrent_fn(params, rng_key, chance_outcome, embedding):
+ def chance_recurrent_fn(params, rng_key, chance_outcome,
+                         afterstate_embedding):
del params, rng_key, chance_outcome
- afterstate_action = embedding
+ afterstate_action, embedding = afterstate_embedding
batch_size = afterstate_action.shape[0]

reward = rewards[jnp.arange(batch_size), afterstate_action]
return mctx.ChanceRecurrentFnOutput(
action_logits=jnp.zeros_like(rewards),
value=jnp.zeros_like(reward),
discount=jnp.zeros_like(reward),
- reward=reward), jnp.zeros([1, 4])
+ reward=reward), embedding

return decision_recurrent_fn, chance_recurrent_fn
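
A usage sketch (not part of this diff) of how a pair built by _make_bandit_decision_and_chance_fns could be passed to the policy; it assumes stochastic_muzero_policy also takes params, rng_key and root, mirroring the muzero_policy call shown further down in this test:

import jax
import jax.numpy as jnp
import mctx

num_actions = 4
num_chance_outcomes = 3
batch_size = 2
rewards = jnp.zeros([batch_size, num_actions])
decision_fn, chance_fn = _make_bandit_decision_and_chance_fns(
    rewards, num_chance_outcomes)

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros([batch_size, num_actions]),
    value=jnp.zeros([batch_size]),
    embedding=jnp.zeros([batch_size, 4]))

policy_output = mctx.stochastic_muzero_policy(
    params=(),                       # the bandit fns ignore params
    rng_key=jax.random.PRNGKey(0),
    root=root,
    decision_recurrent_fn=decision_fn,
    chance_recurrent_fn=chance_fn,
    num_simulations=10,
    num_actions=num_actions,
    num_chance_outcomes=num_chance_outcomes)
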

@@ -309,13 +311,15 @@ def test_stochastic_muzero_policy(self):
root = mctx.RootFnOutput(
prior_logits=jnp.array([
[-1.0, 0.0, 2.0, 3.0],
+ [0.0, 2.0, 5.0, -4.0],
]),
- value=jnp.array([0.0]),
- embedding=jnp.zeros([1, 4]),
+ value=jnp.array([1.0, 0.0]),
+ embedding=jnp.zeros([2, 4])
)
rewards = jnp.zeros_like(root.prior_logits)
invalid_actions = jnp.array([
[0.0, 0.0, 0.0, 1.0],
+ [1.0, 0.0, 1.0, 0.0],
])

num_simulations = 10
@@ -326,7 +330,7 @@ def test_stochastic_muzero_policy(self):
root=root,
recurrent_fn=_make_bandit_recurrent_fn(
rewards,
- dummy_embedding=jnp.zeros([1, 4])),
+ dummy_embedding=jnp.zeros_like(root.embedding)),
num_simulations=num_simulations,
invalid_actions=invalid_actions,
dirichlet_fraction=0.0)
