[RLlib] Fix Atari learning test regressions (2 bugs) and 1 minor attention net bug. #18306

Merged · 12 commits · Sep 3, 2021
6 changes: 4 additions & 2 deletions rllib/agents/pg/tests/test_pg.py
@@ -12,10 +12,12 @@


 class TestPG(unittest.TestCase):
-    def setUp(self):
+    @classmethod
+    def setUpClass(cls) -> None:
         ray.init()

-    def tearDown(self):
+    @classmethod
+    def tearDownClass(cls) -> None:
         ray.shutdown()

     def test_pg_compilation(self):
4 changes: 3 additions & 1 deletion rllib/evaluation/collectors/simple_list_collector.py
@@ -192,8 +192,10 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
                         d.itemsize * int(np.product(d.shape[i + 1:]))
                         for i in range(1, len(d.shape))
                     ]
+                    start = self.shift_before - shift_win + 1 + obs_shift + \
+                        view_req.shift_to
                     data = np.lib.stride_tricks.as_strided(
-                        d[self.shift_before - shift_win:],
+                        d[start:start + self.agent_steps],
                         [self.agent_steps, shift_win
                          ] + [d.shape[i] for i in range(1, len(d.shape))],
                         [data_size, data_size] + strides)
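The `as_strided` call above builds overlapping framestack windows over a flat trajectory buffer without copying data; the fix changes where that window starts. A minimal standalone sketch of the same sliding-window trick (array contents, window size, and offset are illustrative, not the collector's actual shapes):

```python
import numpy as np

d = np.arange(10, dtype=np.float32)      # stand-in for a per-agent data column
shift_win, agent_steps, start = 4, 6, 1  # assumed window size, steps, offset

# Row t is the window d[start + t : start + t + shift_win]; no data is copied,
# only the strides are adjusted (same pattern as the collector code above).
# Note: as_strided does no bounds checking; here every window stays inside `d`.
windows = np.lib.stride_tricks.as_strided(
    d[start:start + agent_steps],
    shape=(agent_steps, shift_win),
    strides=(d.itemsize, d.itemsize))
print(windows)  # [[1. 2. 3. 4.], [2. 3. 4. 5.], ..., [6. 7. 8. 9.]]
```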
26 changes: 15 additions & 11 deletions rllib/evaluation/rollout_worker.py
@@ -31,7 +31,7 @@
 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
+from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.error import EnvError
 from ray.rllib.utils.filter import get_filter, Filter
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -399,7 +399,8 @@ def gen_rollouts():
         self.callbacks: DefaultCallbacks = DefaultCallbacks()
         self.worker_index: int = worker_index
         self.num_workers: int = num_workers
-        model_config: ModelConfigDict = model_config or {}
+        model_config: ModelConfigDict = \
+            model_config or self.policy_config.get("model") or {}

         # Default policy mapping fn is to always return DEFAULT_POLICY_ID,
         # independent on the agent ID and the episode passed in.
@@ -486,15 +487,14 @@ def wrap(env):
                 clip_rewards = True

             # Deprecated way of framestacking is used.
-            framestack = model_config.get("framestack") is True
+            use_old_framestack = model_config.get("framestack") is True
             # framestacking via trajectory view API is enabled.
             num_framestacks = model_config.get("num_framestacks", 0)

-            # Trajectory view API is on and num_framestacks=auto:
-            # Only stack traj. view based if old
-            # `framestack=[invalid value]`.
+            # num_framestacks=auto:
+            # Only stack traj. view based if `framestack=[invalid value]`.
             if num_framestacks == "auto":
-                if framestack == DEPRECATED_VALUE:
+                if not use_old_framestack:
                     model_config["num_framestacks"] = num_framestacks = 4
                 else:
                     model_config["num_framestacks"] = num_framestacks = 0

Contributor (on the `use_old_framestack` line): Anyway to say this is deprecated in the logs?

@@ -504,7 +504,7 @@ def wrap(env):
                 env = wrap_deepmind(
                     env,
                     dim=model_config.get("dim"),
-                    framestack=framestack,
+                    framestack=use_old_framestack,
                     framestack_via_traj_view_api=framestack_traj_view)
                 env = record_env_wrapper(env, record_env, log_dir,
                                          policy_config)
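As a reading aid, the deprecation handling in the hunk above can be pulled out into a standalone helper (hypothetical function, not part of RLlib's API; semantics as assumed from the diff):

```python
def resolve_framestacking(model_config: dict):
    """Return (use_old_framestack, num_framestacks) per the logic above."""
    # Legacy env-wrapper stacking only if `framestack` was explicitly True.
    use_old_framestack = model_config.get("framestack") is True
    num_framestacks = model_config.get("num_framestacks", 0)
    if num_framestacks == "auto":
        # Trajectory-view stacking of 4 frames unless the legacy flag is set.
        num_framestacks = 0 if use_old_framestack else 4
        model_config["num_framestacks"] = num_framestacks
    return use_old_framestack, num_framestacks


# The legacy flag wins over "auto":
assert resolve_framestacking(
    {"framestack": True, "num_framestacks": "auto"}) == (True, 0)
```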
@@ -761,7 +761,8 @@ def sample(self) -> SampleBatchType:
             return self.last_batch
         elif self.input_reader is None:
             raise ValueError("RolloutWorker has no `input_reader` object! "
-                             "Cannot call `sample()`.")
+                             "Cannot call `sample()`. You can try setting "
+                             "`create_env_on_driver` to True.")

         if log_once("sample_start"):
             logger.info("Generating sample batch of size {}".format(
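Following the hint added to the error message, a sketch of a trainer config fragment that gives the local worker an env to sample from (`create_env_on_driver` is the key named in the message; the other values are illustrative):

```python
config = {
    "create_env_on_driver": True,  # local (driver) worker creates its own env
    "num_workers": 0,              # illustrative: sample on the driver only
}
```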
@@ -1444,6 +1445,8 @@ def _determine_spaces_for_multi_agent_dict(
         policy_config: Optional[PartialTrainerConfigDict] = None,
 ) -> MultiAgentPolicyConfigDict:

+    policy_config = policy_config or {}
+
     # Try extracting spaces from env or from given spaces dict.
     env_obs_space = None
     env_act_space = None
@@ -1476,14 +1479,15 @@ def _determine_spaces_for_multi_agent_dict(
             obs_space = spaces[pid][0]
         elif env_obs_space is not None:
             obs_space = env_obs_space
-        elif policy_config and policy_config.get("observation_space"):
+        elif policy_config.get("observation_space"):
             obs_space = policy_config["observation_space"]
         else:
             raise ValueError(
                 "`observation_space` not provided in PolicySpec for "
                 f"{pid} and env does not have an observation space OR "
                 "no spaces received from other workers' env(s) OR no "
                 "`observation_space` specified in config!")
+
         multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
             observation_space=obs_space)
@@ -1492,7 +1496,7 @@
             act_space = spaces[pid][1]
         elif env_act_space is not None:
             act_space = env_act_space
-        elif policy_config and policy_config.get("action_space"):
+        elif policy_config.get("action_space"):
             act_space = policy_config["action_space"]
         else:
             raise ValueError(
1 change: 0 additions & 1 deletion rllib/evaluation/worker_set.py
@@ -418,7 +418,6 @@ def valid_module(class_path):
             normalize_actions=config["normalize_actions"],
             clip_actions=config["clip_actions"],
             env_config=config["env_config"],
-            model_config=config["model"],
             policy_config=config,
             worker_index=worker_index,
             num_workers=num_workers,
9 changes: 5 additions & 4 deletions rllib/models/catalog.py
@@ -133,15 +133,16 @@
     # == Atari ==
     # Which framestacking size to use for Atari envs.
     # "auto": Use a value of 4, but only if the env is an Atari env.
-    # > 1: Use the trajectory view API in the default VisionNets to request the
+    # > 1: Use the trajectory view API via the default VisionNets to request
     #      last n observations (single, grayscaled 84x84 image frames) as
     #      inputs. The time axis in the so provided observation tensors
     #      will come right after the batch axis (channels first format),
     #      e.g. BxTx84x84, where T=num_framestacks.
     # 0 or 1: No framestacking used.
-    # Use the deprecated `framestack=True`, to disable the above behavor and to
-    # enable legacy stacking behavior (w/o trajectory view API) instead.
-    "num_framestacks": "auto",
+    # Use the (deprecated) `framestack=True`, to disable the above behavior
+    # and to enable legacy stacking behavior (w/o using trajectory view API)
+    # instead.
+    "num_framestacks": 0,
     # Final resized frame dimension
     "dim": 84,
     # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
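A hedged illustration of the two framestacking modes documented in the comment block above (model-config keys as above; concrete values illustrative):

```python
# Trajectory-view-API framestacking: the default VisionNet itself requests the
# last 4 grayscaled 84x84 frames (observation tensors shaped BxTx84x84, T=4).
traj_view_stacking = {"num_framestacks": 4, "dim": 84}

# Legacy behavior: no trajectory-view stacking; the (deprecated)
# `framestack=True` flag enables the old env-wrapper stacking instead.
legacy_stacking = {"num_framestacks": 0, "framestack": True, "dim": 84}
```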
12 changes: 9 additions & 3 deletions rllib/models/tf/visionnet.py
@@ -93,7 +93,10 @@ def __init__(self, obs_space: gym.spaces.Space,
             layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
                                                      if post_fcnet_hiddens else
                                                      [])
+            feature_out = last_layer
+
             for i, out_size in enumerate(layer_sizes):
+                feature_out = last_layer
                 last_layer = tf.keras.layers.Dense(
                     out_size,
                     name="post_fcnet_{}".format(i),
@@ -126,6 +129,7 @@ def __init__(self, obs_space: gym.spaces.Space,
             # Add (optional) post-fc-stack after last Conv2D layer.
             for i, out_size in enumerate(post_fcnet_hiddens[1:] +
                                          [num_outputs]):
+                feature_out = last_layer
                 last_layer = tf.keras.layers.Dense(
                     out_size,
                     name="post_fcnet_{}".format(i + 1),
@@ -134,6 +138,7 @@ def __init__(self, obs_space: gym.spaces.Space,
                     kernel_initializer=normc_initializer(1.0))(
                         last_layer)
             else:
+                feature_out = last_layer
                 last_cnn = last_layer = tf.keras.layers.Conv2D(
                     num_outputs, [1, 1],
                     activation=None,
@@ -164,19 +169,20 @@ def __init__(self, obs_space: gym.spaces.Space,
                     name="post_fcnet_{}".format(i),
                     activation=post_fcnet_activation,
                     kernel_initializer=normc_initializer(1.0))(last_layer)
+            feature_out = last_layer
             self.num_outputs = last_layer.shape[1]
             logits_out = last_layer

         # Build the value layers
         if vf_share_layers:
             if not self.last_layer_is_flattened:
-                last_layer = tf.keras.layers.Lambda(
-                    lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+                feature_out = tf.keras.layers.Lambda(
+                    lambda x: tf.squeeze(x, axis=[1, 2]))(feature_out)
                 value_out = tf.keras.layers.Dense(
                     1,
                     name="value_out",
                     activation=None,
-                    kernel_initializer=normc_initializer(0.01))(last_layer)
+                    kernel_initializer=normc_initializer(0.01))(feature_out)
             else:
                 # build a parallel set of hidden layers for the value net
                 last_layer = inputs
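The `feature_out` threading above appears intended so that, with `vf_share_layers=True`, the value head is built on the last shared hidden features rather than on the logits layer itself. A minimal Keras sketch of that wiring (hypothetical layer sizes and names, not the actual RLlib VisionNet):

```python
import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(84, 84, 4), name="obs")
x = tf.keras.layers.Conv2D(16, 8, strides=4, activation="relu")(inputs)
x = tf.keras.layers.Conv2D(32, 4, strides=2, activation="relu")(x)
feature_out = tf.keras.layers.Flatten()(x)           # shared features
logits_out = tf.keras.layers.Dense(6)(feature_out)   # e.g. 6 discrete actions
value_out = tf.keras.layers.Dense(1)(feature_out)    # value head reads the
                                                      # features, not the logits
model = tf.keras.Model(inputs, [logits_out, value_out])
```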
1 change: 1 addition & 0 deletions rllib/models/torch/visionnet.py
@@ -217,6 +217,7 @@ def __init__(self, obs_space: gym.spaces.Space,
         self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
             data_col=SampleBatch.OBS,
             shift="-{}:1".format(from_ - 1),
+            used_for_compute_actions=False,
             space=self.view_requirements[SampleBatch.OBS].space,
         )

37 changes: 22 additions & 15 deletions rllib/policy/sample_batch.py
@@ -942,21 +942,28 @@ def get_single_step_input_dict(self, view_requirements, index="last"):
             data_col = last_mappings.get(data_col, data_col)
             # Range needed.
             if view_req.shift_from is not None:
-                data = self[view_col][-1]
-                traj_len = len(self[data_col])
-                missing_at_end = traj_len % view_req.batch_repeat_value
-                obs_shift = -1 if data_col in [
-                    SampleBatch.OBS, SampleBatch.NEXT_OBS
-                ] else 0
-                from_ = view_req.shift_from + obs_shift
-                to_ = view_req.shift_to + obs_shift + 1
-                if to_ == 0:
-                    to_ = None
-                input_dict[view_col] = np.array([
-                    np.concatenate(
-                        [data,
-                         self[data_col][-missing_at_end:]])[from_:to_]
-                ])
+                # Batch repeat value > 1: We have single frames in the
+                # batch at each timestep.
+                if view_req.batch_repeat_value > 1:
+                    data = self[view_col][-1]
+                    traj_len = len(self[data_col])
+                    missing_at_end = traj_len % view_req.batch_repeat_value
+                    obs_shift = -1 if data_col in [
+                        SampleBatch.OBS, SampleBatch.NEXT_OBS
+                    ] else 0
+                    from_ = view_req.shift_from + obs_shift
+                    to_ = view_req.shift_to + obs_shift + 1
+                    if to_ == 0:
+                        to_ = None
+                    input_dict[view_col] = np.array([
+                        np.concatenate(
+                            [self[data_col][-missing_at_end:],
+                             data])[from_:to_]
+                    ])
+                # Batch repeat value = 1: We already have framestacks
+                # at each timestep.
+                else:
+                    input_dict[view_col] = self[data_col][-1][None]
             # Single index.
             else:
                 data = self[data_col][-1]

Contributor (author, on the `np.concatenate` call): data has to be last in the concat, otherwise, e.g. an attention net will not necessarily see the most recent observations. This explains the learning enhancements on the RepeatAfterMe experiments vs older versions.

Contributor: Makes sense.
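A minimal sketch of the ordering point made in the author's comment above: the `[from_:to_]` slice is taken from the end of the concatenated array, so the most recent entry (`data`) has to come last or it falls outside the requested range (values illustrative):

```python
import numpy as np

buffer_tail = np.array([1, 2, 3])  # older steps: self[data_col][-missing_at_end:]
data = np.array([4])               # newest step: self[view_col][-1]
from_, to_ = -3, None              # e.g. an attention net asking for the last 3 steps

np.concatenate([buffer_tail, data])[from_:to_]  # -> [2, 3, 4]  (newest included)
np.concatenate([data, buffer_tail])[from_:to_]  # -> [1, 2, 3]  (newest dropped)
```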