Commit
[RLlib] Fix Atari learning test regressions (2 bugs) and 1 minor attention net bug. (#18306)
sven1977 authored Sep 3, 2021
1 parent fb38d06 commit 9a8ca6a
Showing 13 changed files with 98 additions and 177 deletions.
27 changes: 27 additions & 0 deletions release/rllib_tests/learning_tests/hard_learning_tests.yaml
@@ -52,6 +52,33 @@ apex-breakoutnoframeskip-v4:
target_network_update_freq: 50000
timesteps_per_iteration: 25000

appo-pong-no-frameskip-v4:
env: PongNoFrameskip-v4
run: APPO
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 18.0
timesteps_total: 5000000
stop:
time_total_s: 2000
config:
vtrace: True
use_kl_loss: False
rollout_fragment_length: 50
train_batch_size: 750
num_workers: 31
broadcast_interval: 1
max_sample_requests_in_flight_per_worker: 1
num_multi_gpu_tower_stacks: 1
num_envs_per_worker: 8
num_sgd_iter: 2
vf_loss_coeff: 1.0
clip_param: 0.3
num_gpus: 1
grad_clip: 10
model:
dim: 42

ddpg-hopperbulletenv-v0:
env: HopperBulletEnv-v0
run: DDPG
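For reference, the new `appo-pong-no-frameskip-v4` entry maps almost one-to-one onto a plain `tune.run` call. A stand-alone sketch only (it assumes the 2021-era `ray.tune` API this repo used at the time; the `pass_criteria` block has no tune equivalent, since the release-test harness evaluates it):

```python
# Stand-alone sketch of the new YAML entry above (values copied verbatim).
import ray
from ray import tune

ray.init()
tune.run(
    "APPO",
    stop={"time_total_s": 2000},
    config={
        "env": "PongNoFrameskip-v4",
        "vtrace": True,
        "use_kl_loss": False,
        "rollout_fragment_length": 50,
        "train_batch_size": 750,
        "num_workers": 31,
        "broadcast_interval": 1,
        "max_sample_requests_in_flight_per_worker": 1,
        "num_multi_gpu_tower_stacks": 1,
        "num_envs_per_worker": 8,
        "num_sgd_iter": 2,
        "vf_loss_coeff": 1.0,
        "clip_param": 0.3,
        "num_gpus": 1,
        "grad_clip": 10,
        # dim=42: resize Atari frames to 42x42 (see the `wrap_deepmind`
        # change in this same commit).
        "model": {"dim": 42},
    },
)
```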
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -1845,7 +1845,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=20"]
args = ["--as-test", "--stop-reward=60"]
)

py_test(
@@ -1854,7 +1854,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=20", "--framework torch"]
args = ["--as-test", "--stop-reward=60", "--framework torch"]
)

py_test(
6 changes: 4 additions & 2 deletions rllib/agents/pg/tests/test_pg.py
@@ -12,10 +12,12 @@


class TestPG(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls) -> None:
ray.init()

def tearDown(self):
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_pg_compilation(self):
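The switch from per-test `setUp`/`tearDown` to class-level fixtures starts and stops Ray once per test class instead of once per test method, cutting repeated cluster-startup cost. A minimal sketch of the pattern (plain `unittest`, nothing PG-specific):

```python
import unittest

import ray


class TestExample(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()  # Started once, before the first test in this class.

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()  # Stopped once, after the last test.

    def test_ray_is_up(self):
        self.assertTrue(ray.is_initialized())


if __name__ == "__main__":
    unittest.main()
```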
2 changes: 1 addition & 1 deletion rllib/env/utils.py
@@ -61,7 +61,7 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
For VizDoom support: Install VizDoom
(https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
`pip install vizdoomgym`.
For PyBullet support: `pip install pybullet pybullet_envs`.
For PyBullet support: `pip install pybullet`.
b) To register your custom env, do `from ray import tune;
tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
Then in your config, do `config['env'] = [name]`.
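The install hint shrinks because the PyBullet gym environments ship inside the `pybullet` wheel itself (as the `pybullet_envs` module), so no second package is needed. A sketch of how the envs used by e.g. the `ddpg-hopperbulletenv-v0` test get registered (assuming `pybullet` is installed):

```python
import gym
import pybullet_envs  # noqa: F401 -- importing registers the Bullet envs.

env = gym.make("HopperBulletEnv-v0")
obs = env.reset()
print(env.observation_space, env.action_space)
```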
17 changes: 4 additions & 13 deletions rllib/env/wrappers/atari_wrappers.py
@@ -283,17 +283,13 @@ def observation(self, observation):
return np.array(observation).astype(np.float32) / 255.0


def wrap_deepmind(
env,
dim=84,
# TODO: (sven) Remove once traj. view is norm.
framestack=True,
framestack_via_traj_view_api=False):
def wrap_deepmind(env, dim=84, framestack=True):
"""Configure environment for DeepMind-style Atari.
Note that we assume reward clipping is done outside the wrapper.
Args:
env (EnvType): The env object to wrap.
dim (int): Dimension to resize observations to (dim x dim).
framestack (bool): Whether to framestack observations.
"""
@@ -307,12 +303,7 @@ def wrap_deepmind(
env = WarpFrame(env, dim)
# env = ScaledFloatFrame(env) # TODO: use for dqn?
# env = ClipRewardEnv(env) # reward clipping is handled by policy eval
# New way of frame stacking via the trajectory view API (model config key:
# `num_framestacks=[int]`.
if framestack_via_traj_view_api:
env = FrameStackTrajectoryView(env)
# Old way (w/o traj. view API) via model config key: `framestack=True`.
# TODO: (sven) Remove once traj. view is norm.
elif framestack is True:
# 4x image framestacking.
if framestack is True:
env = FrameStack(env, 4)
return env
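A short usage sketch of the simplified signature (assumes `gym` with the Atari extras installed; `dim=42` matches the new APPO test config above). With `framestack=True`, the `FrameStack` wrapper stacks the last 4 warped frames along the channel axis:

```python
import gym

from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind

env = wrap_deepmind(gym.make("PongNoFrameskip-v4"), dim=42, framestack=True)
obs = env.reset()
print(obs.shape)  # Expected: (42, 42, 4) -- 4 stacked 42x42 grayscale frames.
```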
4 changes: 3 additions & 1 deletion rllib/evaluation/collectors/simple_list_collector.py
@@ -192,8 +192,10 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
d.itemsize * int(np.product(d.shape[i + 1:]))
for i in range(1, len(d.shape))
]
start = self.shift_before - shift_win + 1 + obs_shift + \
view_req.shift_to
data = np.lib.stride_tricks.as_strided(
d[self.shift_before - shift_win:],
d[start:start + self.agent_steps],
[self.agent_steps, shift_win
] + [d.shape[i] for i in range(1, len(d.shape))],
[data_size, data_size] + strides)
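This hunk is likely one of the two regressions named in the commit title: the base index of the strided view was `self.shift_before - shift_win` regardless of `obs_shift` and `view_req.shift_to`, so the windows could be offset relative to the steps they belong to; the fix computes an explicit `start`. A toy illustration of the same `as_strided` windowing pattern (plain NumPy, not RLlib internals):

```python
# Row t of the result is the window d[start + t : start + t + shift_win],
# sharing memory with d rather than copying it.
import numpy as np

d = np.arange(10, dtype=np.float32)
shift_win, agent_steps, start = 3, 4, 2

windows = np.lib.stride_tricks.as_strided(
    d[start:],  # Offsetting the base buffer is what the fix adds.
    shape=(agent_steps, shift_win),
    strides=(d.itemsize, d.itemsize),
)
print(windows)
# [[2. 3. 4.]
#  [3. 4. 5.]
#  [4. 5. 6.]
#  [5. 6. 7.]]
```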
39 changes: 15 additions & 24 deletions rllib/evaluation/rollout_worker.py
@@ -31,7 +31,7 @@
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.error import EnvError
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -400,7 +400,8 @@ def gen_rollouts():
self.callbacks: DefaultCallbacks = DefaultCallbacks()
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = model_config or {}
model_config: ModelConfigDict = \
model_config or self.policy_config.get("model") or {}

# Default policy mapping fn is to always return DEFAULT_POLICY_ID,
# independent on the agent ID and the episode passed in.
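The new `model_config` fallback above pairs with the `worker_set.py` change further down, which stops passing `model_config` explicitly: each `RolloutWorker` now recovers the model options from its `policy_config`, so Atari model settings such as `dim` reach the env wrappers again. A plain-Python sketch of the resolution order (illustrative values only):

```python
# Resolution order: explicit argument > policy_config["model"] > empty dict.
policy_config = {"model": {"dim": 42, "framestack": True}}
model_config = None

model_config = model_config or policy_config.get("model") or {}
print(model_config)  # {'dim': 42, 'framestack': True}
```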
@@ -464,27 +465,14 @@ def wrap(env):
if clip_rewards is None:
clip_rewards = True

# Deprecated way of framestacking is used.
framestack = model_config.get("framestack") is True
# framestacking via trajectory view API is enabled.
num_framestacks = model_config.get("num_framestacks", 0)

# Trajectory view API is on and num_framestacks=auto:
# Only stack traj. view based if old
# `framestack=[invalid value]`.
if num_framestacks == "auto":
if framestack == DEPRECATED_VALUE:
model_config["num_framestacks"] = num_framestacks = 4
else:
model_config["num_framestacks"] = num_framestacks = 0
framestack_traj_view = num_framestacks > 1
# Framestacking is used.
use_framestack = model_config.get("framestack") is True

def wrap(env):
env = wrap_deepmind(
env,
dim=model_config.get("dim"),
framestack=framestack,
framestack_via_traj_view_api=framestack_traj_view)
framestack=use_framestack)
env = record_env_wrapper(env, record_env, log_dir,
policy_config)
return env
@@ -740,7 +728,8 @@ def sample(self) -> SampleBatchType:
return self.last_batch
elif self.input_reader is None:
raise ValueError("RolloutWorker has no `input_reader` object! "
"Cannot call `sample()`.")
"Cannot call `sample()`. You can try setting "
"`create_env_on_driver` to True.")

if log_once("sample_start"):
logger.info("Generating sample batch of size {}".format(
@@ -1423,6 +1412,8 @@ def _determine_spaces_for_multi_agent_dict(
policy_config: Optional[PartialTrainerConfigDict] = None,
) -> MultiAgentPolicyConfigDict:

policy_config = policy_config or {}

# Try extracting spaces from env or from given spaces dict.
env_obs_space = None
env_act_space = None
@@ -1455,14 +1446,15 @@ def _determine_spaces_for_multi_agent_dict(
obs_space = spaces[pid][0]
elif env_obs_space is not None:
obs_space = env_obs_space
elif policy_config and policy_config.get("observation_space"):
elif policy_config.get("observation_space"):
obs_space = policy_config["observation_space"]
else:
raise ValueError(
"`observation_space` not provided in PolicySpec for "
f"{pid} and env does not have an observation space OR "
"no spaces received from other workers' env(s) OR no "
"`observation_space` specified in config!")

multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
observation_space=obs_space)

@@ -1471,7 +1463,7 @@ def _determine_spaces_for_multi_agent_dict(
act_space = spaces[pid][1]
elif env_act_space is not None:
act_space = env_act_space
elif policy_config and policy_config.get("action_space"):
elif policy_config.get("action_space"):
act_space = policy_config["action_space"]
else:
raise ValueError(
@@ -1489,8 +1481,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
msg = f"Validating sub-env at vector index={env_context.vector_index} ..."

allowed_types = [
gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv,
ray.actor.ActorHandle
gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle
]
if not any(isinstance(env, tpe) for tpe in allowed_types):
# Allow this as a special case (assumed gym.Env).
@@ -1508,7 +1499,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
f"(type={type(env)}).")

# Do some test runs with the provided env.
if isinstance(env, gym.Env):
if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
# Make sure the gym.Env has the two space attributes properly set.
assert hasattr(env, "observation_space") and hasattr(
env, "action_space")
1 change: 0 additions & 1 deletion rllib/evaluation/worker_set.py
@@ -418,7 +418,6 @@ def valid_module(class_path):
normalize_actions=config["normalize_actions"],
clip_actions=config["clip_actions"],
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
num_workers=num_workers,
29 changes: 3 additions & 26 deletions rllib/models/catalog.py
@@ -131,17 +131,8 @@
"attention_use_n_prev_rewards": 0,

# == Atari ==
# Which framestacking size to use for Atari envs.
# "auto": Use a value of 4, but only if the env is an Atari env.
# > 1: Use the trajectory view API in the default VisionNets to request the
# last n observations (single, grayscaled 84x84 image frames) as
# inputs. The time axis in the so provided observation tensors
# will come right after the batch axis (channels first format),
# e.g. BxTx84x84, where T=num_framestacks.
# 0 or 1: No framestacking used.
# Use the deprecated `framestack=True`, to disable the above behavor and to
# enable legacy stacking behavior (w/o trajectory view API) instead.
"num_framestacks": "auto",
# Set to True to enable 4x stacking behavior.
"framestack": True,
# Final resized frame dimension
"dim": 84,
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
@@ -166,8 +157,6 @@
# Deprecated keys:
# Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
"lstm_use_prev_action_reward": DEPRECATED_VALUE,
# Use `num_framestacks` (int) instead.
"framestack": True,
}
# __sphinx_doc_end__
# yapf: enable
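With `num_framestacks` gone, Atari framestacking is controlled by the single boolean `framestack` key again (also removed from the deprecated-keys list above). A sketch of the relevant model options after this change (config shape taken from this diff):

```python
# Model options controlling Atari preprocessing after this change.
config = {
    "env": "PongNoFrameskip-v4",
    "model": {
        "framestack": True,  # 4x frame stacking via the FrameStack wrapper.
        "dim": 42,           # Resize frames to 42x42 before stacking.
    },
}
```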
@@ -807,10 +796,6 @@ def _get_v2_model_class(input_space: gym.Space,
"framework={} not supported in `ModelCatalog._get_v2_model_"
"class`!".format(framework))

# Discrete/1D obs-spaces or 2D obs space but traj. view framestacking
# disabled.
num_framestacks = model_config.get("num_framestacks", "auto")

# Tuple space, where at least one sub-space is image.
# -> Complex input model.
space_to_check = input_space if not hasattr(
@@ -824,8 +809,7 @@
# Single, flattenable/one-hot-able space -> Simple FCNet.
if isinstance(input_space, (Discrete, MultiDiscrete)) or \
len(input_space.shape) == 1 or (
len(input_space.shape) == 2 and (
num_framestacks == "auto" or num_framestacks <= 1)):
len(input_space.shape) == 2):
# Keras native requested AND no auto-rnn-wrapping.
if model_config.get("_use_default_native_models") and Keras_FCNet:
return Keras_FCNet
@@ -886,10 +870,3 @@ def _validate_config(config: ModelConfigDict, framework: str) -> None:
elif config.get("use_lstm"):
raise ValueError("`use_lstm` not available for "
"framework=jax so far!")

if config.get("framestack") != DEPRECATED_VALUE:
# deprecation_warning(
# old="framestack", new="num_framestacks (int)", error=False)
# If old behavior is desired, disable traj. view-style
# framestacking.
config["num_framestacks"] = 0