[RLlib] Fix Atari learning test regressions (2 bugs) and 1 minor attention net bug. #18306

Merged · 12 commits · Sep 3, 2021
27 changes: 27 additions & 0 deletions release/rllib_tests/learning_tests/hard_learning_tests.yaml
@@ -52,6 +52,33 @@ apex-breakoutnoframeskip-v4:
target_network_update_freq: 50000
timesteps_per_iteration: 25000

appo-pong-no-frameskip-v4:
env: PongNoFrameskip-v4
run: APPO
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 18.0
timesteps_total: 5000000
stop:
time_total_s: 2000
config:
vtrace: True
use_kl_loss: False
rollout_fragment_length: 50
train_batch_size: 750
num_workers: 31
broadcast_interval: 1
max_sample_requests_in_flight_per_worker: 1
num_multi_gpu_tower_stacks: 1
num_envs_per_worker: 8
num_sgd_iter: 2
vf_loss_coeff: 1.0
clip_param: 0.3
num_gpus: 1
grad_clip: 10
model:
dim: 42

ddpg-hopperbulletenv-v0:
env: HopperBulletEnv-v0
run: DDPG
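For context (not part of this diff): a rough Python sketch of what a release-test harness could do with the new YAML entry above, assuming the Ray 1.x `tune.run` API; the actual harness code differs.

    import ray
    from ray import tune

    ray.init()
    # Run the named algorithm with the YAML config until the stop criteria hit,
    # then compare the final result against the pass_criteria.
    analysis = tune.run(
        "APPO",
        config={
            "env": "PongNoFrameskip-v4",
            "vtrace": True,
            "use_kl_loss": False,
            "rollout_fragment_length": 50,
            "train_batch_size": 750,
            "num_workers": 31,
            "num_envs_per_worker": 8,
            "num_gpus": 1,
            "grad_clip": 10,
            "model": {"dim": 42},
        },
        stop={"time_total_s": 2000},
    )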
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -1845,7 +1845,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=20"]
args = ["--as-test", "--stop-reward=60"]
)

py_test(
@@ -1854,7 +1854,7 @@
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=20", "--framework torch"]
args = ["--as-test", "--stop-reward=60", "--framework torch"]
)

py_test(
6 changes: 4 additions & 2 deletions rllib/agents/pg/tests/test_pg.py
@@ -12,10 +12,12 @@


class TestPG(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls) -> None:
ray.init()

def tearDown(self):
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_pg_compilation(self):
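The change above moves Ray startup from per-test to per-class. A minimal sketch of the resulting pattern on a hypothetical test class; `ray.init()`/`ray.shutdown()` are comparatively expensive, so running them once per class speeds up the suite:

    import unittest
    import ray

    class TestExample(unittest.TestCase):
        @classmethod
        def setUpClass(cls) -> None:
            # Start Ray once for all test methods in this class.
            ray.init()

        @classmethod
        def tearDownClass(cls) -> None:
            # Stop Ray after the last test method has run.
            ray.shutdown()

        def test_something(self):
            self.assertTrue(ray.is_initialized())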
2 changes: 1 addition & 1 deletion rllib/env/utils.py
@@ -61,7 +61,7 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
For VizDoom support: Install VizDoom
(https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
`pip install vizdoomgym`.
For PyBullet support: `pip install pybullet pybullet_envs`.
For PyBullet support: `pip install pybullet`.
b) To register your custom env, do `from ray import tune;
tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
Then in your config, do `config['env'] = [name]`.
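For reference, a minimal sketch of registering a custom env as the error text above suggests; the registration helper is `register_env` from `ray.tune.registry`, and the env and name below are illustrative only.

    import gym
    from ray.tune.registry import register_env

    def env_creator(cfg):
        # Build and return the env instance from the per-worker config dict.
        return gym.make("CartPole-v0")

    register_env("my_env", env_creator)
    # Then, in the trainer config: config["env"] = "my_env"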
17 changes: 4 additions & 13 deletions rllib/env/wrappers/atari_wrappers.py
@@ -283,17 +283,13 @@ def observation(self, observation):
return np.array(observation).astype(np.float32) / 255.0


def wrap_deepmind(
env,
dim=84,
# TODO: (sven) Remove once traj. view is norm.
framestack=True,
framestack_via_traj_view_api=False):
def wrap_deepmind(env, dim=84, framestack=True):
"""Configure environment for DeepMind-style Atari.

Note that we assume reward clipping is done outside the wrapper.

Args:
env (EnvType): The env object to wrap.
dim (int): Dimension to resize observations to (dim x dim).
framestack (bool): Whether to framestack observations.
"""
@@ -307,12 +303,7 @@ wrap_deepmind(
env = WarpFrame(env, dim)
# env = ScaledFloatFrame(env) # TODO: use for dqn?
# env = ClipRewardEnv(env) # reward clipping is handled by policy eval
# New way of frame stacking via the trajectory view API (model config key:
# `num_framestacks=[int]`.
if framestack_via_traj_view_api:
env = FrameStackTrajectoryView(env)
# Old way (w/o traj. view API) via model config key: `framestack=True`.
# TODO: (sven) Remove once traj. view is norm.
elif framestack is True:
# 4x image framestacking.
if framestack is True:
env = FrameStack(env, 4)
return env
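
A hypothetical manual use of the simplified wrapper, matching the `dim: 42` model setting in the new YAML entry (sketch only, not part of this diff):

    import gym
    from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind

    env = wrap_deepmind(gym.make("PongNoFrameskip-v4"), dim=42, framestack=True)
    obs = env.reset()
    # WarpFrame resizes to 42x42 grayscale; FrameStack stacks 4 frames on the
    # last axis, so the expected observation shape is (42, 42, 4).
    print(obs.shape)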
4 changes: 3 additions & 1 deletion rllib/evaluation/collectors/simple_list_collector.py
@@ -192,8 +192,10 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
d.itemsize * int(np.product(d.shape[i + 1:]))
for i in range(1, len(d.shape))
]
start = self.shift_before - shift_win + 1 + obs_shift + \
view_req.shift_to
data = np.lib.stride_tricks.as_strided(
d[self.shift_before - shift_win:],
d[start:start + self.agent_steps],
[self.agent_steps, shift_win
] + [d.shape[i] for i in range(1, len(d.shape))],
[data_size, data_size] + strides)
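The fix above moves the start index of the strided window to account for the observation shift. A standalone NumPy illustration of the `as_strided` sliding-window trick (illustrative only, not RLlib code):

    import numpy as np

    buf = np.arange(8, dtype=np.int64)  # [0 1 2 3 4 5 6 7]
    win, steps = 3, 6                   # window length, number of windows
    view = np.lib.stride_tricks.as_strided(
        buf, shape=(steps, win), strides=(buf.itemsize, buf.itemsize))
    # Row t is a zero-copy view of buf[t:t + win]; shifting the start index of
    # the underlying buffer shifts every window by the same amount.
    print(view[0], view[-1])            # [0 1 2] [5 6 7]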
39 changes: 15 additions & 24 deletions rllib/evaluation/rollout_worker.py
@@ -31,7 +31,7 @@
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.error import EnvError
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -400,7 +400,8 @@ def gen_rollouts():
self.callbacks: DefaultCallbacks = DefaultCallbacks()
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = model_config or {}
model_config: ModelConfigDict = \
model_config or self.policy_config.get("model") or {}

# Default policy mapping fn is to always return DEFAULT_POLICY_ID,
# independent on the agent ID and the episode passed in.
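
The `model_config` fallback added above can be read as the following standalone helper (illustrative only; the helper name is not part of the diff):

    def resolve_model_config(model_config, policy_config):
        # Prefer an explicitly passed model config, then the "model" sub-dict
        # of the policy config, then an empty dict.
        return model_config or policy_config.get("model") or {}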
@@ -464,27 +465,14 @@ def wrap(env):
if clip_rewards is None:
clip_rewards = True

# Deprecated way of framestacking is used.
framestack = model_config.get("framestack") is True
# framestacking via trajectory view API is enabled.
num_framestacks = model_config.get("num_framestacks", 0)

# Trajectory view API is on and num_framestacks=auto:
# Only stack traj. view based if old
# `framestack=[invalid value]`.
if num_framestacks == "auto":
if framestack == DEPRECATED_VALUE:
model_config["num_framestacks"] = num_framestacks = 4
else:
model_config["num_framestacks"] = num_framestacks = 0
framestack_traj_view = num_framestacks > 1
# Framestacking is used.
use_framestack = model_config.get("framestack") is True

def wrap(env):
env = wrap_deepmind(
env,
dim=model_config.get("dim"),
framestack=framestack,
framestack_via_traj_view_api=framestack_traj_view)
framestack=use_framestack)
env = record_env_wrapper(env, record_env, log_dir,
policy_config)
return env
@@ -740,7 +728,8 @@ def sample(self) -> SampleBatchType:
return self.last_batch
elif self.input_reader is None:
raise ValueError("RolloutWorker has no `input_reader` object! "
"Cannot call `sample()`.")
"Cannot call `sample()`. You can try setting "
"`create_env_on_driver` to True.")

if log_once("sample_start"):
logger.info("Generating sample batch of size {}".format(
@@ -1423,6 +1412,8 @@ def _determine_spaces_for_multi_agent_dict(
policy_config: Optional[PartialTrainerConfigDict] = None,
) -> MultiAgentPolicyConfigDict:

policy_config = policy_config or {}

# Try extracting spaces from env or from given spaces dict.
env_obs_space = None
env_act_space = None
@@ -1455,14 +1446,15 @@ def _determine_spaces_for_multi_agent_dict(
obs_space = spaces[pid][0]
elif env_obs_space is not None:
obs_space = env_obs_space
elif policy_config and policy_config.get("observation_space"):
elif policy_config.get("observation_space"):
obs_space = policy_config["observation_space"]
else:
raise ValueError(
"`observation_space` not provided in PolicySpec for "
f"{pid} and env does not have an observation space OR "
"no spaces received from other workers' env(s) OR no "
"`observation_space` specified in config!")

multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
observation_space=obs_space)
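
When neither the env nor other workers can supply spaces, the branch above falls back to spaces given directly in the policy config. A hedged sketch of such a config (the spaces below are chosen arbitrarily for illustration):

    import gym

    config = {
        "observation_space": gym.spaces.Box(-1.0, 1.0, (4,)),
        "action_space": gym.spaces.Discrete(2),
    }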

@@ -1471,7 +1463,7 @@ def _determine_spaces_for_multi_agent_dict(
act_space = spaces[pid][1]
elif env_act_space is not None:
act_space = env_act_space
elif policy_config and policy_config.get("action_space"):
elif policy_config.get("action_space"):
act_space = policy_config["action_space"]
else:
raise ValueError(
@@ -1489,8 +1481,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
msg = f"Validating sub-env at vector index={env_context.vector_index} ..."

allowed_types = [
gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv,
ray.actor.ActorHandle
gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle
]
if not any(isinstance(env, tpe) for tpe in allowed_types):
# Allow this as a special case (assumed gym.Env).
@@ -1508,7 +1499,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
f"(type={type(env)}).")

# Do some test runs with the provided env.
if isinstance(env, gym.Env):
if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
# Make sure the gym.Env has the two space attributes properly set.
assert hasattr(env, "observation_space") and hasattr(
env, "action_space")
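The added `not isinstance(env, MultiAgentEnv)` guard matters because a MultiAgentEnv works with per-agent dicts rather than single observations. A minimal hypothetical example of such an env:

    from ray.rllib.env.multi_agent_env import MultiAgentEnv

    class TwoAgentEnv(MultiAgentEnv):
        # reset()/step() return dicts keyed by agent ID, so the single-agent
        # gym.Env sanity checks in _validate_env must not run on this env.
        def reset(self):
            return {"agent_0": [0.0], "agent_1": [0.0]}

        def step(self, action_dict):
            obs = {aid: [0.0] for aid in action_dict}
            rewards = {aid: 1.0 for aid in action_dict}
            dones = {"__all__": True}
            return obs, rewards, dones, {}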
1 change: 0 additions & 1 deletion rllib/evaluation/worker_set.py
@@ -418,7 +418,6 @@ def valid_module(class_path):
normalize_actions=config["normalize_actions"],
clip_actions=config["clip_actions"],
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
num_workers=num_workers,
29 changes: 3 additions & 26 deletions rllib/models/catalog.py
@@ -131,17 +131,8 @@
"attention_use_n_prev_rewards": 0,

# == Atari ==
# Which framestacking size to use for Atari envs.
# "auto": Use a value of 4, but only if the env is an Atari env.
# > 1: Use the trajectory view API in the default VisionNets to request the
# last n observations (single, grayscaled 84x84 image frames) as
# inputs. The time axis in the so provided observation tensors
# will come right after the batch axis (channels first format),
# e.g. BxTx84x84, where T=num_framestacks.
# 0 or 1: No framestacking used.
# Use the deprecated `framestack=True` to disable the above behavior and to
# enable legacy stacking behavior (w/o trajectory view API) instead.
"num_framestacks": "auto",
# Set to True to enable 4x stacking behavior.
"framestack": True,
# Final resized frame dimension
"dim": 84,
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
@@ -166,8 +157,6 @@
# Deprecated keys:
# Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
"lstm_use_prev_action_reward": DEPRECATED_VALUE,
# Use `num_framestacks` (int) instead.
"framestack": True,
}
# __sphinx_doc_end__
# yapf: enable
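
With `num_framestacks` gone, framestacking for Atari goes back to the plain `framestack` flag. A hypothetical trainer config using the restored keys (values mirror the new APPO Pong YAML entry above):

    config = {
        "env": "PongNoFrameskip-v4",
        "model": {
            "framestack": True,  # 4x frame stacking in the Atari wrappers
            "dim": 42,           # resize frames to 42x42
        },
    }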
@@ -807,10 +796,6 @@ def _get_v2_model_class(input_space: gym.Space,
"framework={} not supported in `ModelCatalog._get_v2_model_"
"class`!".format(framework))

# Discrete/1D obs-spaces or 2D obs space but traj. view framestacking
# disabled.
num_framestacks = model_config.get("num_framestacks", "auto")

# Tuple space, where at least one sub-space is image.
# -> Complex input model.
space_to_check = input_space if not hasattr(
@@ -824,8 +809,7 @@ def _get_v2_model_class(input_space: gym.Space,
# Single, flattenable/one-hot-able space -> Simple FCNet.
if isinstance(input_space, (Discrete, MultiDiscrete)) or \
len(input_space.shape) == 1 or (
len(input_space.shape) == 2 and (
num_framestacks == "auto" or num_framestacks <= 1)):
len(input_space.shape) == 2):
# Keras native requested AND no auto-rnn-wrapping.
if model_config.get("_use_default_native_models") and Keras_FCNet:
return Keras_FCNet
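
The simplified condition above dispatches purely on space type and rank: Discrete/MultiDiscrete and 1D or 2D Box spaces get a fully connected net, while image-like (3D) spaces fall through to a vision net. A rough standalone sketch of that check, assuming gym spaces (not the actual catalog code):

    import numpy as np
    from gym.spaces import Box, Discrete, MultiDiscrete

    def uses_fcnet(space):
        # Flat or low-rank spaces -> FCNet; higher-rank (image) spaces -> VisionNet.
        return isinstance(space, (Discrete, MultiDiscrete)) or \
            len(space.shape) in (1, 2)

    print(uses_fcnet(Discrete(4)))                             # True
    print(uses_fcnet(Box(0.0, 1.0, (84, 84, 3), np.float32)))  # False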
@@ -886,10 +870,3 @@ def _validate_config(config: ModelConfigDict, framework: str) -> None:
elif config.get("use_lstm"):
raise ValueError("`use_lstm` not available for "
"framework=jax so far!")

if config.get("framestack") != DEPRECATED_VALUE:
# deprecation_warning(
# old="framestack", new="num_framestacks (int)", error=False)
# If old behavior is desired, disable traj. view-style
# framestacking.
config["num_framestacks"] = 0