[CI] Flake8 upgrade #15527

Merged May 3, 2021 (21 commits)
Changes from 1 commit
format rllib utils and tests
amogkam committed Apr 27, 2021
commit 845f129cdd1f942f1e7bca8d531ab984255d0ff2
rllib/tests/test_rollout.py (10 changes: 4 additions & 6 deletions)
@@ -52,7 +52,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
rllib_dir, algo, checkpoint_path, tmp_dir)).read()
if not os.path.exists(tmp_dir + "/rollouts_10steps.pkl"):
sys.exit(1)
- print("rollout output (10 steps) exists!".format(checkpoint_path))
+ print("rollout output (10 steps) exists!")

# Test rolling out 1 episode.
if test_episode_rollout:
@@ -61,7 +61,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
rllib_dir, algo, checkpoint_path, tmp_dir)).read()
if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"):
sys.exit(1)
- print("rollout output (1 ep) exists!".format(checkpoint_path))
+ print("rollout output (1 ep) exists!")

# Cleanup.
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
@@ -115,8 +115,7 @@ def learn_test_plus_rollout(algo, env="CartPole-v0"):
rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
sys.exit(1)
- print("Rollout output exists -> Checking reward ...".format(
-     checkpoint_path))
+ print("Rollout output exists -> Checking reward ...")
episodes = result.split("\n")
mean_reward = 0.0
num_episodes = 0
@@ -208,8 +207,7 @@ def policy_fn(agent):
rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
sys.exit(1)
- print("Rollout output exists -> Checking reward ...".format(
-     checkpoint_path))
+ print("Rollout output exists -> Checking reward ...")
episodes = result.split("\n")
mean_reward = 0.0
num_episodes = 0
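The test_rollout.py hunks above drop .format(checkpoint_path) calls on strings that contain no placeholders; recent pyflakes (which the upgraded flake8 pulls in) flags that pattern with its unused-argument checks for str.format (F523 for unused positional arguments). A minimal, hypothetical snippet showing the flagged and the fixed spelling, not code from the PR:

# f523_demo.py -- illustrative only.
checkpoint_path = "/tmp/checkpoint_000001"

# Flagged: the format string has no {} placeholders, so the argument
# passed to .format() is never used.
print("rollout output (10 steps) exists!".format(checkpoint_path))

# Fixed: drop the pointless .format() call.
print("rollout output (10 steps) exists!")

Running flake8 over such a file should point at the first print call only; the second is clean.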
rllib/utils/exploration/exploration.py (1 change: 1 addition & 0 deletions)
@@ -11,6 +11,7 @@

if TYPE_CHECKING:
from ray.rllib.policy.policy import Policy
+ import tensorflow as tf

_, nn = try_import_torch()

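The exploration.py hunk adds an import inside the typing.TYPE_CHECKING guard: the name becomes visible to linters and type checkers, but nothing extra is imported at runtime. A small sketch of that pattern (the annotated function below is invented for illustration, it is not code from the PR):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers and linters; never executed at runtime,
    # so tensorflow does not need to be installed to import this module.
    import tensorflow as tf


def apply_gradients(optimizer: "tf.keras.optimizers.Optimizer") -> None:
    # The quoted annotation is a forward reference, resolved lazily.
    ...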
rllib/utils/spaces/space_utils.py (8 changes: 4 additions & 4 deletions)
@@ -17,16 +17,16 @@ def flatten_space(space):
does not contain Tuples or Dicts anymore.
"""

- def _helper_flatten(space_, l):
+ def _helper_flatten(space_, return_list):
from ray.rllib.utils.spaces.flexdict import FlexDict
if isinstance(space_, Tuple):
for s in space_:
- _helper_flatten(s, l)
+ _helper_flatten(s, return_list)
elif isinstance(space_, (Dict, FlexDict)):
for k in space_.spaces:
- _helper_flatten(space_[k], l)
+ _helper_flatten(space_[k], return_list)
else:
- l.append(space_)
+ return_list.append(space_)

ret = []
_helper_flatten(space, ret)
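The space_utils.py change renames the single-letter parameter l to return_list; pycodestyle's E741 ("ambiguous variable name") rejects l, I, and O because they are easy to misread, and the upgraded flake8 enforces it here. A standalone illustration, not code from the repository:

# e741_demo.py -- illustrative only.

def total_length_bad(items):
    l = 0  # flake8: E741 ambiguous variable name 'l'
    for item in items:
        l += len(item)
    return l


def total_length_good(items):
    total = 0  # a descriptive name avoids the warning
    for item in items:
        total += len(item)
    return total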
rllib/utils/typing.py (13 changes: 9 additions & 4 deletions)
@@ -5,6 +5,7 @@
# Note: Policy config dicts are usually the same as TrainerConfigDict, but
# parts of it may sometimes be altered in e.g. a multi-agent setup,
# where we have >1 Policies in the same Trainer.

TrainerConfigDict = dict

# A trainer config dict that only has overrides. It needs to be combined with
@@ -69,14 +70,17 @@

# Represents a ViewRequirements dict mapping column names (str) to
# ViewRequirement objects.
- ViewRequirementsDict = Dict[str, "ViewRequirement"]
+ from ray.rllib.policy.view_requirement import ViewRequirement
+ ViewRequirementsDict = Dict[str, ViewRequirement]

# Represents the result dict returned by Trainer.train().
ResultDict = dict

# A tf or torch local optimizer object.
- LocalOptimizer = Union["tf.keras.optimizers.Optimizer",
-                        "torch.optim.Optimizer"]
+ import tensorflow as tf
+ import torch
+ LocalOptimizer = Union[tf.keras.optimizers.Optimizer,
+                        torch.optim.Optimizer]

# Dict of tensors returned by compute gradients on the policy, e.g.,
# {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}, for multi-agent,
@@ -104,7 +108,8 @@
ModelInputDict = Dict[str, TensorType]

# Some kind of sample batch.
- SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
+ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
+ SampleBatchType = Union[SampleBatch, MultiAgentBatch]

# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = Union[TensorType, dict, tuple]
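In this commit the typing.py aliases switch from quoted forward references to real module-level imports of tensorflow, torch, and rllib classes. A common alternative that keeps the module import-light is to put those imports behind TYPE_CHECKING and keep the references quoted; the sketch below shows that pattern and is not the code this PR merges:

from typing import TYPE_CHECKING, Dict, Union

if TYPE_CHECKING:
    # Resolved only by type checkers and linters; avoids importing heavy
    # frameworks (and rllib modules that could create import cycles)
    # whenever this module is imported.
    import tensorflow as tf
    import torch
    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

# Quoted forward references stay valid at runtime even though the names
# above are never actually imported.
ViewRequirementsDict = Dict[str, "ViewRequirement"]
LocalOptimizer = Union["tf.keras.optimizers.Optimizer",
                       "torch.optim.Optimizer"]
SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]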