diff --git a/examples/her/her_dqn_gridworld.py b/examples/her/her_dqn_gridworld.py
new file mode 100644
index 000000000..eb56e9adc
--- /dev/null
+++ b/examples/her/her_dqn_gridworld.py
@@ -0,0 +1,66 @@
+"""
+This should result in an average return of -20 by the end of training.
+
+Usually hits -30 around epoch 50.
+Note that one epoch = 1k steps, so 100 epochs = 100k steps.
+"""
+import gym
+
+import rlkit.torch.pytorch_util as ptu
+from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
+from rlkit.launchers.launcher_util import setup_logger
+from rlkit.torch.her.her import HerDQN
+from rlkit.torch.networks import FlattenMlp
+import multiworld.envs.gridworlds  # noqa: F401, registers GoalGridworld-v0
+
+
+def experiment(variant):
+    env = gym.make('GoalGridworld-v0')
+
+    obs_dim = env.observation_space.spaces['observation'].low.size
+    goal_dim = env.observation_space.spaces['desired_goal'].low.size
+    action_dim = env.action_space.n
+    qf = FlattenMlp(
+        input_size=obs_dim + goal_dim,
+        output_size=action_dim,
+        hidden_sizes=[400, 300],
+    )
+
+    replay_buffer = ObsDictRelabelingBuffer(
+        env=env,
+        **variant['replay_buffer_kwargs']
+    )
+    algorithm = HerDQN(
+        her_kwargs=dict(
+            observation_key='observation',
+            desired_goal_key='desired_goal'
+        ),
+        dqn_kwargs=dict(
+            env=env,
+            qf=qf,
+        ),
+        replay_buffer=replay_buffer,
+        **variant['algo_kwargs']
+    )
+    algorithm.to(ptu.device)
+    algorithm.train()
+
+
+if __name__ == "__main__":
+    variant = dict(
+        algo_kwargs=dict(
+            num_epochs=100,
+            num_steps_per_epoch=1000,
+            num_steps_per_eval=1000,
+            max_path_length=50,
+            batch_size=128,
+            discount=0.99,
+        ),
+        replay_buffer_kwargs=dict(
+            max_size=100000,
+            fraction_goals_rollout_goals=0.2,  # equal to k = 4 in HER paper
+            fraction_goals_env_goals=0.0,
+        ),
+    )
+    setup_logger('her-dqn-gridworld-experiment', variant=variant)
+    experiment(variant)
diff --git a/rlkit/data_management/obs_dict_replay_buffer.py b/rlkit/data_management/obs_dict_replay_buffer.py
index f7a69763c..bae897d03 100644
--- a/rlkit/data_management/obs_dict_replay_buffer.py
+++ b/rlkit/data_management/obs_dict_replay_buffer.py
@@ -66,8 +66,10 @@ def __init__(
         self.observation_key = observation_key
         self.desired_goal_key = desired_goal_key
         self.achieved_goal_key = achieved_goal_key
-
-        self._action_dim = env.action_space.low.size
+        if isinstance(self.env.action_space, Discrete):
+            self._action_dim = env.action_space.n
+        else:
+            self._action_dim = env.action_space.low.size
         self._actions = np.zeros((max_size, self._action_dim))
         # self._terminals[i] = a terminal was received at time i
         self._terminals = np.zeros((max_size, 1), dtype='uint8')
@@ -93,8 +95,6 @@ def __init__(
         # Then self._next_obs[j] is a valid next observation for observation i
         self._idx_to_future_obs_idx = [None] * max_size
 
-        if isinstance(self.env.action_space, Discrete):
-            raise NotImplementedError("TODO. See issue 28.")
 
     def add_sample(self, observation, action, reward, terminal,
                    next_observation, **kwargs):
@@ -115,6 +115,9 @@ def add_path(self, path):
         path_len = len(rewards)
 
         actions = flatten_n(actions)
+        if isinstance(self.env.action_space, Discrete):
+            # Store discrete actions as one-hot rows to match self._actions' width.
+            actions = np.eye(self._action_dim)[actions].reshape((-1, self._action_dim))
         obs = flatten_dict(obs, self.ob_keys_to_save + self.internal_keys)
         next_obs = flatten_dict(next_obs, self.ob_keys_to_save + self.internal_keys)
 
diff --git a/rlkit/samplers/rollout_functions.py b/rlkit/samplers/rollout_functions.py
index 2701efcef..217fe443b 100644
--- a/rlkit/samplers/rollout_functions.py
+++ b/rlkit/samplers/rollout_functions.py
@@ -68,7 +68,6 @@ def multitask_rollout(
         if d:
             break
         o = next_o
-        full_observations.append(o)
     actions = np.array(actions)
     if len(actions.shape) == 1:
         actions = np.expand_dims(actions, 1)
diff --git a/rlkit/torch/her/her.py b/rlkit/torch/her/her.py
index fa5f92c3e..e7518037f 100644
--- a/rlkit/torch/her/her.py
+++ b/rlkit/torch/her/her.py
@@ -9,6 +9,7 @@
 from rlkit.torch.sac.sac import SoftActorCritic
 from rlkit.torch.sac.twin_sac import TwinSAC
 from rlkit.torch.td3.td3 import TD3
+from rlkit.torch.dqn.dqn import DQN
 from rlkit.torch.torch_rl_algorithm import TorchRLAlgorithm
 
 
@@ -194,3 +195,20 @@ def __init__(
         ) or isinstance(
             self.replay_buffer, ObsDictRelabelingBuffer
         )
+
+
+class HerDQN(HER, DQN):
+    def __init__(
+            self,
+            *args,
+            her_kwargs,
+            dqn_kwargs,
+            **kwargs
+    ):
+        HER.__init__(self, **her_kwargs)
+        DQN.__init__(self, *args, **kwargs, **dqn_kwargs)
+        assert isinstance(
+            self.replay_buffer, RelabelingReplayBuffer
+        ) or isinstance(
+            self.replay_buffer, ObsDictRelabelingBuffer
+        )
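
Note (not part of the patch): the Discrete branch added to ObsDictRelabelingBuffer.add_path stores integer actions as one-hot rows so they fit the pre-allocated (max_size, action_dim) self._actions array. The standalone sketch below uses a made-up action_dim and action sequence purely to illustrate what np.eye(self._action_dim)[actions] produces; it assumes nothing beyond numpy.

import numpy as np

action_dim = 4                      # stand-in for env.action_space.n
actions = np.array([0, 3, 1, 1])    # integer actions collected along one path

# Same conversion as in add_path: row i is the one-hot encoding of actions[i].
one_hot = np.eye(action_dim)[actions].reshape((-1, action_dim))

assert one_hot.shape == (len(actions), action_dim)
assert (one_hot.argmax(axis=1) == actions).all()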