
Commit

Merge branch 'master' of github.com:vitchyr/rlkit
vitchyr committed Mar 1, 2019
2 parents 78552d3 + c6d8821 commit 86db9c2
Showing 4 changed files with 96 additions and 5 deletions.
73 changes: 73 additions & 0 deletions examples/her/her_dqn_gridworld.py
@@ -0,0 +1,73 @@
"""
This should result in an average return of -20 by the end of training.
Usually hits -30 around epoch 50.
Note that with this configuration one epoch = 1k steps, so 100 epochs = 100k steps.
"""
import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
from rlkit.exploration_strategies.base import (
    PolicyWrappedWithExplorationStrategy
)
from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import (
    GaussianAndEpislonStrategy
)
from rlkit.launchers.launcher_util import setup_logger
from rlkit.torch.her.her import HerDQN
from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy
import multiworld.envs.gridworlds


def experiment(variant):
    env = gym.make('GoalGridworld-v0')

    obs_dim = env.observation_space.spaces['observation'].low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    action_dim = env.action_space.n
    qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )


    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        **variant['replay_buffer_kwargs']
    )
    algorithm = HerDQN(
        her_kwargs=dict(
            observation_key='observation',
            desired_goal_key='desired_goal'
        ),
        dqn_kwargs=dict(
            env=env,
            qf=qf1,
        ),
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()


if __name__ == "__main__":
    variant = dict(
        algo_kwargs=dict(
            num_epochs=100,
            num_steps_per_epoch=1000,
            num_steps_per_eval=1000,
            max_path_length=50,
            batch_size=128,
            discount=0.99,
        ),
        replay_buffer_kwargs=dict(
            max_size=100000,
            fraction_goals_rollout_goals=0.2,  # equal to k = 4 in HER paper
            fraction_goals_env_goals=0.0,
        ),
    )
    setup_logger('her-dqn-gridworld-experiment', variant=variant)
    experiment(variant)
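A note on the buffer settings above: fraction_goals_rollout_goals=0.2 matches the HER paper's "future" strategy with k = 4, under the reading that the fraction of goals kept from the original rollout is 1 / (1 + k). A minimal sketch of that arithmetic (illustrative only, assuming this interpretation of the kwarg):

# Sketch: relate fraction_goals_rollout_goals to HER's k (k relabeled
# goals per original rollout goal). Assumes fraction = 1 / (1 + k).
k = 4
fraction_goals_rollout_goals = 1 / (1 + k)
assert abs(fraction_goals_rollout_goals - 0.2) < 1e-9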
10 changes: 6 additions & 4 deletions rlkit/data_management/obs_dict_replay_buffer.py
@@ -66,8 +66,10 @@ def __init__(
        self.observation_key = observation_key
        self.desired_goal_key = desired_goal_key
        self.achieved_goal_key = achieved_goal_key

        self._action_dim = env.action_space.low.size
        if isinstance(self.env.action_space, Discrete):
            self._action_dim = env.action_space.n
        else:
            self._action_dim = env.action_space.low.size
        self._actions = np.zeros((max_size, self._action_dim))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_size, 1), dtype='uint8')
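For context on the branch added above: a Discrete action space exposes n (the number of actions) but no low/high bounds, while a Box space exposes low, so the isinstance check is needed before reading either attribute. A minimal standalone sketch using gym spaces directly (not rlkit code):

import numpy as np
from gym.spaces import Box, Discrete

# Sketch: how the action dimension is read for each space type.
discrete_space = Discrete(5)
box_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
assert discrete_space.n == 5        # Discrete: number of actions
assert box_space.low.size == 3      # Box: flat size of the bounds array
# Discrete has no .low attribute, so an unconditional read would fail.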
@@ -93,8 +95,6 @@ def __init__(
        # Then self._next_obs[j] is a valid next observation for observation i
        self._idx_to_future_obs_idx = [None] * max_size

        if isinstance(self.env.action_space, Discrete):
            raise NotImplementedError("TODO. See issue 28.")

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
@@ -115,6 +115,8 @@ def add_path(self, path):
        path_len = len(rewards)

        actions = flatten_n(actions)
        if isinstance(self.env.action_space, Discrete):
            actions = np.eye(self._action_dim)[actions].reshape((-1, self._action_dim))
        obs = flatten_dict(obs, self.ob_keys_to_save + self.internal_keys)
        next_obs = flatten_dict(next_obs,
                                self.ob_keys_to_save + self.internal_keys)
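The added branch one-hot encodes the integer actions before they are stored, so the buffer always holds a fixed-width float array regardless of action space type. A minimal standalone sketch of that np.eye indexing, with assumed input shapes:

import numpy as np

# Sketch: integer actions -> one-hot rows, mirroring the np.eye line above.
# Assumes actions arrive as an integer array of shape (T,) or (T, 1); the
# trailing reshape normalizes both cases to (T, action_dim).
action_dim = 4
actions = np.array([0, 2, 3])
one_hot = np.eye(action_dim)[actions].reshape((-1, action_dim))
# one_hot[1] is array([0., 0., 1., 0.])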
1 change: 0 additions & 1 deletion rlkit/samplers/rollout_functions.py
@@ -68,7 +68,6 @@ def multitask_rollout(
        if d:
            break
        o = next_o
        full_observations.append(o)
    actions = np.array(actions)
    if len(actions.shape) == 1:
        actions = np.expand_dims(actions, 1)
17 changes: 17 additions & 0 deletions rlkit/torch/her/her.py
@@ -9,6 +9,7 @@
from rlkit.torch.sac.sac import SoftActorCritic
from rlkit.torch.sac.twin_sac import TwinSAC
from rlkit.torch.td3.td3 import TD3
from rlkit.torch.dqn.dqn import DQN
from rlkit.torch.torch_rl_algorithm import TorchRLAlgorithm


@@ -194,3 +195,19 @@ def __init__(
        ) or isinstance(
            self.replay_buffer, ObsDictRelabelingBuffer
        )

class HerDQN(HER, DQN):
    def __init__(
            self,
            *args,
            her_kwargs,
            dqn_kwargs,
            **kwargs
    ):
        HER.__init__(self, **her_kwargs)
        DQN.__init__(self, *args, **kwargs, **dqn_kwargs)
        assert isinstance(
            self.replay_buffer, RelabelingReplayBuffer
        ) or isinstance(
            self.replay_buffer, ObsDictRelabelingBuffer
        )
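HerDQN follows the same pattern as the other HER wrappers in this file: because HER is listed first in the bases, its overrides win in Python's method resolution order, and the two constructors are called explicitly with their own kwargs. A minimal, illustrative sketch of that mixin pattern (toy classes, not rlkit's actual internals):

class GoalMixin:
    """Toy stand-in for HER: records which dict keys to read."""
    def __init__(self, observation_key, desired_goal_key):
        self.observation_key = observation_key
        self.desired_goal_key = desired_goal_key


class BaseAlgo:
    """Toy stand-in for DQN: owns the env and the Q-function."""
    def __init__(self, env, qf):
        self.env = env
        self.qf = qf


class GoalAlgo(GoalMixin, BaseAlgo):
    def __init__(self, her_kwargs, base_kwargs):
        # Same explicit two-step initialization as HerDQN above.
        GoalMixin.__init__(self, **her_kwargs)
        BaseAlgo.__init__(self, **base_kwargs)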
