Skip to content

Commit

Permalink
[RLlib] Issue ray-project#13342: Add validate_spaces to MB-MPO. (ra…
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Feb 11, 2021
1 parent f6cfc44 commit a2f7998
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
31 changes: 31 additions & 0 deletions rllib/agents/mbmpo/mbmpo_torch_policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gym
from gym.spaces import Box, Discrete
import logging
from typing import Tuple, Type

Expand All @@ -13,6 +14,7 @@
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.typing import TrainerConfigDict
Expand All @@ -22,6 +24,35 @@
logger = logging.getLogger(__name__)


def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Validates the observation- and action spaces used for the Policy.
Args:
policy (Policy): The policy, whose spaces are being validated.
observation_space (gym.spaces.Space): The observation space to
validate.
action_space (gym.spaces.Space): The action space to validate.
config (TrainerConfigDict): The Policy's config dict.
Raises:
UnsupportedSpaceException: If one of the spaces is not supported.
"""
# Only support single Box or single Discrete spaces.
if not isinstance(action_space, (Box, Discrete)):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "
"MB-MPO. Must be [Box|Discrete].".format(action_space, policy))
# If Box, make sure it's a 1D vector space.
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space ({}) of {} has multiple dimensions "
"{}. ".format(action_space, policy, action_space.shape) +
"Consider reshaping this into a single dimension Box space "
"or using the multi-agent API.")


def make_model_and_action_dist(
policy: Policy,
obs_space: gym.spaces.Space,
Expand Down
2 changes: 2 additions & 0 deletions rllib/agents/mbmpo/model_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
obs_space.low[0],
obs_space.high[0],
shape=(obs_space.shape[0] + action_space.shape[0], ))
else:
raise NotImplementedError
super(DynamicsEnsembleCustomModel, self).__init__(
input_space, action_space, num_outputs, model_config, name)

Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/sac/sac_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
Raises:
UnsupportedSpaceException: If one of the spaces is not supported.
"""
# Only support single Box or single Discreete spaces.
# Only support single Box or single Discrete spaces.
if not isinstance(action_space, (Box, Discrete, Simplex)):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "
Expand Down

0 comments on commit a2f7998

Please sign in to comment.