From bf3eaa9264e579e9cc84a641012154fcd803c06c Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 18 Aug 2021 18:47:08 +0200 Subject: [PATCH] [RLlib] Dreamer fixes and reinstate Dreamer test. (#17821) Co-authored-by: sven1977 --- rllib/BUILD | 15 +++--- rllib/agents/dreamer/dreamer.py | 49 +++----------------- rllib/agents/dreamer/dreamer_model.py | 12 +++-- rllib/agents/dreamer/dreamer_torch_policy.py | 49 ++++++++++++++++++-- rllib/agents/dreamer/tests/test_dreamer.py | 31 ++++++++----- rllib/agents/dreamer/utils.py | 4 ++ rllib/evaluation/sampler.py | 30 +++++------- 7 files changed, 102 insertions(+), 88 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 328d55cda9d8..71f5efc46b86 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -647,14 +647,13 @@ py_test( srcs = ["agents/dqn/tests/test_simple_q.py"] ) -# TODO: enable once we have a MuJoCo-independent test case. -## Dreamer -#py_test( -# name = "test_dreamer", -# tags = ["trainers_dir"], -# size = "small", -# srcs = ["agents/dreamer/tests/test_dreamer.py"] -#) +# Dreamer +py_test( + name = "test_dreamer", + tags = ["trainers_dir"], + size = "small", + srcs = ["agents/dreamer/tests/test_dreamer.py"] +) # ES py_test( diff --git a/rllib/agents/dreamer/dreamer.py b/rllib/agents/dreamer/dreamer.py index 8d783514f510..4a8170f52787 100644 --- a/rllib/agents/dreamer/dreamer.py +++ b/rllib/agents/dreamer/dreamer.py @@ -106,9 +106,6 @@ def add(self, batch: SampleBatchType): self.timesteps += batch.count episodes = batch.split_by_episode() - - for i, e in enumerate(episodes): - episodes[i] = self.preprocess_episode(e) self.episodes.extend(episodes) if len(self.episodes) > self.max_length: @@ -116,34 +113,6 @@ def add(self, batch: SampleBatchType): # Drop oldest episodes self.episodes = self.episodes[delta:] - def preprocess_episode(self, episode: SampleBatchType): - """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)) - When t=0, the resetted obs is paired with action and reward of 0. - - Args: - episode: SampleBatch representing an episode - """ - obs = episode["obs"] - new_obs = episode["new_obs"] - action = episode["actions"] - reward = episode["rewards"] - - act_shape = action.shape - act_reset = np.array([0.0] * act_shape[-1])[None] - rew_reset = np.array(0.0)[None] - obs_end = np.array(new_obs[act_shape[0] - 1])[None] - - batch_obs = np.concatenate([obs, obs_end], axis=0) - batch_action = np.concatenate([act_reset, action], axis=0) - batch_rew = np.concatenate([rew_reset, reward], axis=0) - - new_batch = { - "obs": batch_obs, - "rewards": batch_rew, - "actions": batch_action - } - return SampleBatch(new_batch) - def sample(self, batch_size: int): """Samples [batch_size, length] from the list of episodes @@ -158,13 +127,9 @@ def sample(self, batch_size: int): continue available = episode.count - self.length index = int(random.randint(0, available)) - episodes_buffer.append(episode.slice(index, index + self.length)) - - batch = {} - for k in episodes_buffer[0].keys(): - batch[k] = np.stack([e[k] for e in episodes_buffer], axis=0) + episodes_buffer.append(episode[index:index + self.length]) - return SampleBatch(batch) + return SampleBatch.concat_samples(episodes_buffer) def total_sampled_timesteps(worker): @@ -182,12 +147,12 @@ def __init__(self, worker, episode_buffer, dreamer_train_iters, batch_size, def __call__(self, samples): - # Dreamer Training Loop + # Dreamer training loop. 
for n in range(self.dreamer_train_iters): - print(n) + print(f"sub-iteration={n}/{self.dreamer_train_iters}") batch = self.episode_buffer.sample(self.batch_size) - if n == self.dreamer_train_iters - 1: - batch["log_gif"] = True + # if n == self.dreamer_train_iters - 1: + # batch["log_gif"] = True fetches = self.worker.learn_on_batch(batch) # Custom Logging @@ -220,7 +185,7 @@ def policy_stats(self, fetches): def execution_plan(workers, config): - # Special Replay Buffer for Dreamer agent + # Special replay buffer for Dreamer agent. episode_buffer = EpisodicBuffer(length=config["batch_length"]) local_worker = workers.local_worker() diff --git a/rllib/agents/dreamer/dreamer_model.py b/rllib/agents/dreamer/dreamer_model.py index 0f35f58e8c0a..a509800ed645 100644 --- a/rllib/agents/dreamer/dreamer_model.py +++ b/rllib/agents/dreamer/dreamer_model.py @@ -111,7 +111,7 @@ def forward(self, x): orig_shape = list(x.size()) x = self.model(x) - reshape_size = orig_shape[:-1] + self.shape + reshape_size = orig_shape[:-1] + list(self.shape) mean = x.view(*reshape_size) # Equivalent to making a multivariate diag @@ -323,7 +323,7 @@ def observe(self, ) -> Tuple[List[TensorType], List[TensorType]]: """Returns the corresponding states from the embedding from ConvEncoder and actions. This is accomplished by rolling out the RNN from the - starting state through eacn index of embed and action, saving all + starting state through each index of embed and action, saving all intermediate states between. Args: @@ -337,6 +337,12 @@ def observe(self, if state is None: state = self.get_initial_state(action.size()[0]) + if embed.dim() <= 2: + embed = torch.unsqueeze(embed, 1) + + if action.dim() <= 2: + action = torch.unsqueeze(action, 1) + embed = embed.permute(1, 0, 2) action = action.permute(1, 0, 2) @@ -481,7 +487,7 @@ def policy(self, obs: TensorType, state: List[TensorType], explore=True and policy to obtain action. 
""" if state is None: - self.initial_state() + self.state = self.get_initial_state(batch_size=obs.shape[0]) else: self.state = state post = self.state[:4] diff --git a/rllib/agents/dreamer/dreamer_torch_policy.py b/rllib/agents/dreamer/dreamer_torch_policy.py index cc0c1e2a269c..28a0b3c77ed2 100644 --- a/rllib/agents/dreamer/dreamer_torch_policy.py +++ b/rllib/agents/dreamer/dreamer_torch_policy.py @@ -1,11 +1,18 @@ import logging +import numpy as np +from typing import Dict, Optional + import ray from ray.rllib.agents.dreamer.utils import FreezeParameters +from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import apply_grad_clipping +from ray.rllib.utils.typing import AgentID torch, nn = try_import_torch() if torch: @@ -174,7 +181,7 @@ def dreamer_loss(policy, model, dist_class, train_batch): def build_dreamer_model(policy, obs_space, action_space, config): - policy.model = ModelCatalog.get_model_v2( + model = ModelCatalog.get_model_v2( obs_space, action_space, 1, @@ -182,9 +189,9 @@ def build_dreamer_model(policy, obs_space, action_space, config): name="DreamerModel", framework="torch") - policy.model_variables = policy.model.variables() + policy.model_variables = model.variables() - return policy.model + return model def action_sampler_fn(policy, model, input_dict, state, explore, timestep): @@ -197,7 +204,7 @@ def action_sampler_fn(policy, model, input_dict, state, explore, timestep): # Custom Exploration if timestep <= policy.config["prefill_timesteps"]: - logp = [0.0] + logp = None # Random action in space [-1.0, 1.0] action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0 state = model.get_initial_state() @@ -236,11 +243,45 @@ def dreamer_optimizer_fn(policy, config): return (model_opt, actor_opt, critic_opt) +def preprocess_episode( + policy: Policy, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: + """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)) + When t=0, the resetted obs is paired with action and reward of 0. 
+ """ + obs = sample_batch[SampleBatch.OBS] + new_obs = sample_batch[SampleBatch.NEXT_OBS] + action = sample_batch[SampleBatch.ACTIONS] + reward = sample_batch[SampleBatch.REWARDS] + eps_ids = sample_batch[SampleBatch.EPS_ID] + + act_shape = action.shape + act_reset = np.array([0.0] * act_shape[-1])[None] + rew_reset = np.array(0.0)[None] + obs_end = np.array(new_obs[act_shape[0] - 1])[None] + + batch_obs = np.concatenate([obs, obs_end], axis=0) + batch_action = np.concatenate([act_reset, action], axis=0) + batch_rew = np.concatenate([rew_reset, reward], axis=0) + batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0) + + new_batch = { + SampleBatch.OBS: batch_obs, + SampleBatch.REWARDS: batch_rew, + SampleBatch.ACTIONS: batch_action, + SampleBatch.EPS_ID: batch_eps_ids, + } + return SampleBatch(new_batch) + + DreamerTorchPolicy = build_policy_class( name="DreamerTorchPolicy", framework="torch", get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG, action_sampler_fn=action_sampler_fn, + postprocess_fn=preprocess_episode, loss_fn=dreamer_loss, stats_fn=dreamer_stats, make_model=build_dreamer_model, diff --git a/rllib/agents/dreamer/tests/test_dreamer.py b/rllib/agents/dreamer/tests/test_dreamer.py index 2b318866ca48..08b70e8290e3 100644 --- a/rllib/agents/dreamer/tests/test_dreamer.py +++ b/rllib/agents/dreamer/tests/test_dreamer.py @@ -1,11 +1,10 @@ +from gym.spaces import Box import unittest import ray -from ray import tune import ray.rllib.agents.dreamer as dreamer -from ray.rllib.examples.env.dm_control_suite import hopper_hop -from ray.rllib.utils.test_utils import check_compute_single_action, \ - framework_iterator +from ray.rllib.examples.env.random_env import RandomEnv +from ray.rllib.utils.test_utils import framework_iterator class TestDreamer(unittest.TestCase): @@ -20,19 +19,27 @@ def tearDown(self): def test_dreamer_compilation(self): """Test whether an DreamerTrainer can be built with all frameworks.""" config = dreamer.DEFAULT_CONFIG.copy() - tune.register_env("dm_control_hopper_hop", lambda _: hopper_hop()) + config["env_config"] = { + "observation_space": Box(-1.0, 1.0, (3, 64, 64)), + "action_space": Box(-1.0, 1.0, (3, )) + } + # Num episode chunks per batch. + config["batch_size"] = 2 + # Length (ts) of an episode chunk in a batch. + config["batch_length"] = 20 + # Sub-iterations per .train() call. + config["dreamer_train_iters"] = 4 num_iterations = 1 # Test against all frameworks. 
for _ in framework_iterator(config, frameworks="torch"): - for env in ["dm_control_hopper_hop"]: - trainer = dreamer.DREAMERTrainer(config=config, env=env) - for i in range(num_iterations): - results = trainer.train() - print(results) - check_compute_single_action(trainer) - trainer.stop() + trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv) + for i in range(num_iterations): + results = trainer.train() + print(results) + # check_compute_single_action(trainer, include_state=True) + trainer.stop() if __name__ == "__main__": diff --git a/rllib/agents/dreamer/utils.py b/rllib/agents/dreamer/utils.py index f1ba7fd7049e..67f830797bbf 100644 --- a/rllib/agents/dreamer/utils.py +++ b/rllib/agents/dreamer/utils.py @@ -49,6 +49,10 @@ class TanhBijector(torch.distributions.Transform): def __init__(self): super().__init__() + self.bijective = True + self.domain = torch.distributions.constraints.real + self.codomain = torch.distributions.constraints.interval(-1.0, 1.0) + def atanh(self, x): return 0.5 * torch.log((1 + x) / (1 - x)) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 5737c7bd21f8..d4f2f50aa2b8 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -217,9 +217,8 @@ def __init__( self.render = render # Create the rollout generator to use for calls to `get_data()`. - self.rollout_provider = _env_runner( - worker, self.base_env, self.extra_batches.put, - self.rollout_fragment_length, self.horizon, clip_rewards, + self._env_runner = _env_runner( + worker, self.base_env, self.extra_batches.put, self.horizon, normalize_actions, clip_actions, multiple_episodes_in_batch, callbacks, self.perf_stats, soft_horizon, no_done_at_end, observation_fn, self.sample_collector, self.render) @@ -228,7 +227,7 @@ def __init__( @override(SamplerInput) def get_data(self) -> SampleBatchType: while True: - item = next(self.rollout_provider) + item = next(self._env_runner) if isinstance(item, RolloutMetrics): self.metrics_queue.put(item) else: @@ -372,11 +371,11 @@ def __init__( if not sample_collector_class: sample_collector_class = SimpleListCollector self.sample_collector = sample_collector_class( - worker.policy_map, - clip_rewards, - callbacks, - multiple_episodes_in_batch, - rollout_fragment_length, + self.worker.policy_map, + self.clip_rewards, + self.callbacks, + self.multiple_episodes_in_batch, + self.rollout_fragment_length, count_steps_by=count_steps_by) @override(threading.Thread) @@ -395,9 +394,8 @@ def _run(self): queue_putter = self.queue.put extra_batches_putter = ( lambda x: self.extra_batches.put(x, timeout=600.0)) - rollout_provider = _env_runner( - self.worker, self.base_env, extra_batches_putter, - self.rollout_fragment_length, self.horizon, self.clip_rewards, + env_runner = _env_runner( + self.worker, self.base_env, extra_batches_putter, self.horizon, self.normalize_actions, self.clip_actions, self.multiple_episodes_in_batch, self.callbacks, self.perf_stats, self.soft_horizon, self.no_done_at_end, self.observation_fn, @@ -406,7 +404,7 @@ def _run(self): # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is # set to some large number. This is an empirical observation. 
- item = next(rollout_provider) + item = next(env_runner) if isinstance(item, RolloutMetrics): self.metrics_queue.put(item) else: @@ -450,9 +448,7 @@ def _env_runner( worker: "RolloutWorker", base_env: BaseEnv, extra_batch_callback: Callable[[SampleBatchType], None], - rollout_fragment_length: int, horizon: int, - clip_rewards: bool, normalize_actions: bool, clip_actions: bool, multiple_episodes_in_batch: bool, @@ -470,11 +466,7 @@ def _env_runner( worker (RolloutWorker): Reference to the current rollout worker. base_env (BaseEnv): Env implementing BaseEnv. extra_batch_callback (fn): function to send extra batch data to. - rollout_fragment_length (int): Number of episode steps before - `SampleBatch` is yielded. Set to infinity to yield complete - episodes. horizon (int): Horizon of the episode. - clip_rewards (bool): Whether to clip rewards before postprocessing. multiple_episodes_in_batch (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `rollout_fragment_length` in size.
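For readers tracing the main behavioral change above: the episode preprocessing that used to live in EpisodicBuffer.preprocess_episode now runs as the policy's postprocess_fn (preprocess_episode in dreamer_torch_policy.py). The sketch below is illustrative only and not part of the patch; it reproduces the (s_t, a_(t-1), r_(t-1)) alignment with plain NumPy arrays in place of a SampleBatch, and the helper name shift_episode is made up for the example.

# Illustrative sketch only (not part of the patch): the (s_t, a_(t-1), r_(t-1))
# alignment performed by `preprocess_episode`, using plain NumPy arrays.
import numpy as np

def shift_episode(obs, new_obs, actions, rewards):
    # obs, new_obs: [T, ...]; actions: [T, act_dim]; rewards: [T].
    act_reset = np.zeros((1, actions.shape[-1]))  # dummy action before t=0
    rew_reset = np.zeros((1, ))                   # dummy reward before t=0
    obs_end = new_obs[-1:]                        # final observation s_T
    batch_obs = np.concatenate([obs, obs_end], axis=0)            # [T + 1, ...]
    batch_actions = np.concatenate([act_reset, actions], axis=0)  # [T + 1, act_dim]
    batch_rewards = np.concatenate([rew_reset, rewards], axis=0)  # [T + 1]
    return batch_obs, batch_actions, batch_rewards

# A 3-step episode with 4-dim observations and 2-dim actions.
obs = np.zeros((3, 4))
new_obs = np.ones((3, 4))
actions = np.ones((3, 2))
rewards = np.array([0.1, 0.2, 0.3])
o, a, r = shift_episode(obs, new_obs, actions, rewards)
assert o.shape == (4, 4) and a.shape == (4, 2) and r.shape == (4, )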
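Similarly, the two unsqueeze guards added to observe() in dreamer_model.py only insert a missing time axis when batch-only tensors come in, before the permute(1, 0, 2) that follows in the model code. A minimal standalone check, assuming PyTorch is available:

# Illustrative check only (not part of the patch): inputs shaped [B, D]
# gain a time axis and become [B, 1, D], matching the new guards in observe().
import torch

embed = torch.zeros(4, 1024)   # [batch, embed_dim], no time axis
action = torch.zeros(4, 3)     # [batch, act_dim], no time axis

if embed.dim() <= 2:
    embed = torch.unsqueeze(embed, 1)
if action.dim() <= 2:
    action = torch.unsqueeze(action, 1)

assert embed.shape == (4, 1, 1024)
assert action.shape == (4, 1, 3)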