Skip to content

Commit

Permalink
[rllib] Properly flatten 2-d observations as input to FCnet (ray-proj…
Browse files (browse the repository at this point in the history)
  • Loading branch information
ericl committed Sep 19, 2019
1 parent 7131166 commit 6da7eff
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 6 deletions.
6 changes: 3 additions & 3 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
else:
obs_rank = len(obs_space.shape)

if obs_rank > 1:
if obs_rank > 2:
return PyTorchVisionNet(obs_space, action_space, num_outputs,
model_config, name)

Expand Down Expand Up @@ -506,7 +506,7 @@ def _get_model(input_dict, obs_space, action_space, num_outputs, options,

obs_rank = len(input_dict["obs"].shape) - 1

if obs_rank > 1:
if obs_rank > 2:
return VisionNetwork(input_dict, obs_space, action_space,
num_outputs, options)

Expand All @@ -521,7 +521,7 @@ def _get_v2_model(obs_space, options):
if options.get("use_lstm"):
return None # TODO: default LSTM v2 not implemented

if obs_rank > 1:
if obs_rank > 2:
return VisionNetV2

return FCNetV2
Expand Down
12 changes: 12 additions & 0 deletions rllib/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ def _validate_output_shape(self):
self._num_outputs, shape))


@DeveloperAPI
def flatten(obs, framework):
    """Flatten the given tensor using the requested framework's flatten op.

    Args:
        obs: Tensor to flatten. NOTE(review): presumably a batched
            observation of shape [batch, ...] -- confirm against callers.
        framework (str): Which tensor library `obs` belongs to; either
            "tf" or "torch".

    Returns:
        The flattened tensor. For "torch", every dimension from index 1
        onward is collapsed into one, so dimension 0 is preserved.

    Raises:
        NotImplementedError: If `framework` is neither "tf" nor "torch".
    """
    if framework == "tf":
        return tf.layers.flatten(obs)
    if framework == "torch":
        # Imported lazily so torch is only required when it is actually used.
        import torch
        # start_dim=1 leaves dimension 0 untouched.
        return torch.flatten(obs, start_dim=1)
    # Name the offending framework explicitly; the previous tuple-style args
    # rendered as an opaque "('flatten', 'x')" message.
    raise NotImplementedError(
        "flatten() is not implemented for framework {!r}; "
        "expected \"tf\" or \"torch\"".format(framework))


@DeveloperAPI
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
"""Unpacks Dict and Tuple space observations into their original form.
Expand Down
7 changes: 5 additions & 2 deletions rllib/models/modelv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import print_function

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.models.model import restore_original_dimensions, flatten
from ray.rllib.utils.annotations import PublicAPI


Expand Down Expand Up @@ -146,7 +146,10 @@ def __call__(self, input_dict, state=None, seq_lens=None):
restored = input_dict.copy()
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
restored["obs_flat"] = input_dict["obs"]
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
with self.context():
res = self.forward(restored, state or [], seq_lens)
if ((not isinstance(res, list) and not isinstance(res, tuple))
Expand Down
3 changes: 3 additions & 0 deletions rllib/models/tf/fcnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def _build_layers(self, inputs, num_outputs, options):
hiddens = options.get("fcnet_hiddens")
activation = get_activation_fn(options.get("fcnet_activation"))

if len(inputs.shape) > 2:
inputs = tf.layers.flatten(inputs)

with tf.name_scope("fc_net"):
i = 1
last_layer = inputs
Expand Down
5 changes: 4 additions & 1 deletion rllib/models/tf/fcnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
from ray.rllib.utils import try_import_tf
Expand All @@ -22,8 +24,9 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")

# we are using obs_flat, so take the flattened shape as input
inputs = tf.keras.layers.Input(
shape=obs_space.shape, name="observations")
shape=(np.product(obs_space.shape), ), name="observations")
last_layer = inputs
i = 1

Expand Down
2 changes: 2 additions & 0 deletions rllib/tests/test_supported_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
OBSERVATION_SPACES_TO_TEST = {
"discrete": Discrete(5),
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
"vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
"image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
"atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
"tuple": Tuple([Discrete(10),
Expand Down Expand Up @@ -106,6 +107,7 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
def check_support_multiagent(alg, config):
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
register_env("multi_cartpole", lambda _: MultiCartpole(2))
config["log_level"] = "ERROR"
if "DDPG" in alg:
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
else:
Expand Down

0 comments on commit 6da7eff

Please sign in to comment.