Skip to content

Commit

Permalink
Fix pytype failures related to teaching pytype about NumPy scalar types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 519128722
  • Loading branch information
hawkinsp authored and MctxDev committed Mar 24, 2023
1 parent 5ad16c2 commit 87e0cd7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mctx/_src/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def gumbel_muzero_policy(
# a smaller number of valid actions.
considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
# The completed_qvalues include imputed values for unvisited actions.
completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(
completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])( # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long
search_tree, search_tree.ROOT_INDEX)
to_argmax = seq_halving.score_considered(
considered_visit, gumbel, root.prior_logits, completed_qvalues,
Expand Down
9 changes: 5 additions & 4 deletions mctx/_src/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def summary(self) -> SearchSummary:
visit_probs = visit_counts / jnp.maximum(total_counts, 1)
visit_probs = jnp.where(total_counts > 0, visit_probs, 1 / self.num_actions)
# Return relevant stats.
return SearchSummary(
return SearchSummary( # pytype: disable=wrong-arg-types # numpy-scalars
visit_counts=visit_counts,
visit_probs=visit_probs,
value=value,
Expand Down Expand Up @@ -134,6 +134,7 @@ class SearchSummary:

def _unbatched_qvalues(tree: Tree, index: int) -> int:
chex.assert_rank(tree.children_discounts, 2)
return (tree.children_rewards[index] +
tree.children_discounts[index] * tree.children_values[index])

return ( # pytype: disable=bad-return-type # numpy-scalars
tree.children_rewards[index]
+ tree.children_discounts[index] * tree.children_values[index]
)

0 comments on commit 87e0cd7

Please sign in to comment.