From 3b8e8bcab019e90eb102ca9403e0519d3b5cbe16 Mon Sep 17 00:00:00 2001 From: Ivo Danihelka Date: Fri, 10 Mar 2023 03:04:13 -0800 Subject: [PATCH] Fix pylint warnings. PiperOrigin-RevId: 515586552 --- .pylintrc | 20 ++++---------------- mctx/_src/policies.py | 6 +++--- mctx/_src/search.py | 2 +- mctx/_src/tests/tree_test.py | 1 + 4 files changed, 9 insertions(+), 20 deletions(-) diff --git a/.pylintrc b/.pylintrc index 9cdfb9c..140a876 100644 --- a/.pylintrc +++ b/.pylintrc @@ -76,7 +76,7 @@ disable=abstract-method, global-statement, hex-method, idiv-method, - implicit-str-concat-in-sequence, + implicit-str-concat, import-error, import-self, import-star-module-level, @@ -155,12 +155,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -279,12 +273,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - # Maximum number of lines in a module max-module-lines=99999 @@ -436,6 +424,6 @@ valid-metaclass-classmethod-first-arg=mcs # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException +overgeneral-exceptions=builtins.StandardError, + builtins.Exception, + builtins.BaseException diff --git a/mctx/_src/policies.py b/mctx/_src/policies.py index 8d0c71f..8ae4a3a 100644 --- a/mctx/_src/policies.py +++ b/mctx/_src/policies.py @@ -307,7 +307,7 @@ def stochastic_muzero_policy( prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions)) # construct a dummy afterstate embedding - batch_size = jax.tree_leaves(root.embedding)[0].shape[0] + batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0] dummy_action = jnp.zeros([batch_size], dtype=jnp.int32) _, dummy_afterstate_embedding = decision_recurrent_fn(params, rng_key, dummy_action, @@ -428,7 +428,7 @@ def stochastic_recurrent_fn( action_or_chance: base.Action, # [B] state: base.StochasticRecurrentState ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentState]: - batch_size = jax.tree_leaves(state.state_embedding)[0].shape[0] + batch_size = jax.tree_util.tree_leaves(state.state_embedding)[0].shape[0] # Internally we assume that there are `A' = A + C` "actions"; # action_or_chance can take on values in `{0, 1, ..., A' - 1}`,. # To interpret it as an action we can leave it as is: @@ -509,7 +509,7 @@ def _take_slice(x): elif mode == 'chance': return x[..., num_actions:] else: - raise Exception(f'Unknown mode: {mode}.') + raise ValueError(f'Unknown mode: {mode}.') return tree.replace( children_index=_take_slice(tree.children_index), diff --git a/mctx/_src/search.py b/mctx/_src/search.py index 1ccac93..d14610a 100644 --- a/mctx/_src/search.py +++ b/mctx/_src/search.py @@ -324,7 +324,7 @@ def update_tree_node( # When using max_depth, a leaf can be expanded multiple times. new_visit = tree.node_visits[batch_range, node_index] + 1 - updates = dict( + updates = dict( # pylint: disable=use-dict-literal children_prior_logits=batch_update( tree.children_prior_logits, prior_logits, node_index), raw_values=batch_update( diff --git a/mctx/_src/tests/tree_test.py b/mctx/_src/tests/tree_test.py index f42025d..370fae3 100644 --- a/mctx/_src/tests/tree_test.py +++ b/mctx/_src/tests/tree_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """A unit test comparing the search tree to an expected search tree.""" +# pylint: disable=use-dict-literal import functools import json