[RLlib] Dreamer fixes and reinstate Dreamer test. #17821

Merged · 10 commits · Aug 18, 2021
15 changes: 7 additions & 8 deletions rllib/BUILD
@@ -647,14 +647,13 @@ py_test(
srcs = ["agents/dqn/tests/test_simple_q.py"]
)

# TODO: enable once we have a MuJoCo-independent test case.
## Dreamer
#py_test(
# name = "test_dreamer",
# tags = ["trainers_dir"],
# size = "small",
# srcs = ["agents/dreamer/tests/test_dreamer.py"]
#)
# Dreamer
py_test(
name = "test_dreamer",
tags = ["trainers_dir"],
size = "small",
srcs = ["agents/dreamer/tests/test_dreamer.py"]
)

# ES
py_test(
49 changes: 7 additions & 42 deletions rllib/agents/dreamer/dreamer.py
@@ -106,44 +106,13 @@ def add(self, batch: SampleBatchType):

self.timesteps += batch.count
episodes = batch.split_by_episode()

for i, e in enumerate(episodes):
episodes[i] = self.preprocess_episode(e)
self.episodes.extend(episodes)

if len(self.episodes) > self.max_length:
delta = len(self.episodes) - self.max_length
# Drop oldest episodes
self.episodes = self.episodes[delta:]

def preprocess_episode(self, episode: SampleBatchType):
"""Batch format should be in the form of (s_t, a_(t-1), r_(t-1))
When t=0, the resetted obs is paired with action and reward of 0.

Args:
episode: SampleBatch representing an episode
"""
obs = episode["obs"]
new_obs = episode["new_obs"]
action = episode["actions"]
reward = episode["rewards"]

act_shape = action.shape
act_reset = np.array([0.0] * act_shape[-1])[None]
rew_reset = np.array(0.0)[None]
obs_end = np.array(new_obs[act_shape[0] - 1])[None]

batch_obs = np.concatenate([obs, obs_end], axis=0)
batch_action = np.concatenate([act_reset, action], axis=0)
batch_rew = np.concatenate([rew_reset, reward], axis=0)

new_batch = {
"obs": batch_obs,
"rewards": batch_rew,
"actions": batch_action
}
return SampleBatch(new_batch)

def sample(self, batch_size: int):
"""Samples [batch_size, length] from the list of episodes

@@ -158,13 +127,9 @@ def sample(self, batch_size: int):
continue
available = episode.count - self.length
index = int(random.randint(0, available))
episodes_buffer.append(episode.slice(index, index + self.length))

batch = {}
for k in episodes_buffer[0].keys():
batch[k] = np.stack([e[k] for e in episodes_buffer], axis=0)
episodes_buffer.append(episode[index:index + self.length])

return SampleBatch(batch)
return SampleBatch.concat_samples(episodes_buffer)


def total_sampled_timesteps(worker):
@@ -182,12 +147,12 @@ def __init__(self, worker, episode_buffer, dreamer_train_iters, batch_size,

def __call__(self, samples):

# Dreamer Training Loop
# Dreamer training loop.
for n in range(self.dreamer_train_iters):
print(n)
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
batch = self.episode_buffer.sample(self.batch_size)
if n == self.dreamer_train_iters - 1:
batch["log_gif"] = True
# if n == self.dreamer_train_iters - 1:
# batch["log_gif"] = True
fetches = self.worker.learn_on_batch(batch)

# Custom Logging
@@ -220,7 +185,7 @@ def policy_stats(self, fetches):


def execution_plan(workers, config):
# Special Replay Buffer for Dreamer agent
# Special replay buffer for Dreamer agent.
episode_buffer = EpisodicBuffer(length=config["batch_length"])

local_worker = workers.local_worker()
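Note on the simplified EpisodicBuffer.sample() above: each sampled episode now contributes one contiguous batch_length slice, and the slices are merged with SampleBatch.concat_samples() instead of being stacked by hand. A minimal sketch of that call pattern, using toy data rather than real Dreamer rollouts:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

length = 4  # corresponds to config["batch_length"]
episodes = [
    SampleBatch({"obs": np.arange(10), "rewards": np.ones(10)}),
    SampleBatch({"obs": np.arange(20), "rewards": np.ones(20)}),
]

slices = []
for ep in episodes:
    start = np.random.randint(0, ep.count - length)
    # Contiguous chunk of a single episode, as in sample() above.
    slices.append(ep[start:start + length])

batch = SampleBatch.concat_samples(slices)
print(batch.count)  # 2 * length = 8 rows

Unlike the removed np.stack() version, which built a [batch_size, length, ...] array per key, concat_samples() joins the slices along the first axis.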
12 changes: 9 additions & 3 deletions rllib/agents/dreamer/dreamer_model.py
@@ -111,7 +111,7 @@ def forward(self, x):
orig_shape = list(x.size())
x = self.model(x)

reshape_size = orig_shape[:-1] + self.shape
reshape_size = orig_shape[:-1] + list(self.shape)
mean = x.view(*reshape_size)

# Equivalent to making a multivariate diag
@@ -323,7 +323,7 @@ def observe(self,
) -> Tuple[List[TensorType], List[TensorType]]:
"""Returns the corresponding states from the embedding from ConvEncoder
and actions. This is accomplished by rolling out the RNN from the
starting state through eacn index of embed and action, saving all
starting state through each index of embed and action, saving all
intermediate states between.

Args:
@@ -337,6 +337,12 @@
if state is None:
state = self.get_initial_state(action.size()[0])

if embed.dim() <= 2:
embed = torch.unsqueeze(embed, 1)

if action.dim() <= 2:
action = torch.unsqueeze(action, 1)

embed = embed.permute(1, 0, 2)
action = action.permute(1, 0, 2)

@@ -481,7 +487,7 @@ def policy(self, obs: TensorType, state: List[TensorType], explore=True
and policy to obtain action.
"""
if state is None:
self.initial_state()
self.state = self.get_initial_state(batch_size=obs.shape[0])
else:
self.state = state
post = self.state[:4]
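The reshape fix in forward() above is needed because self.shape is typically a tuple, and Python cannot concatenate a list with a tuple. A small illustrative snippet (sizes and names are placeholders, not RLlib code):

import torch

orig_shape = [2, 5, 3 * 64 * 64]  # e.g. [batch, time, flattened pixels]
dist_shape = (3, 64, 64)          # decoder output shape (a tuple)

# orig_shape[:-1] + dist_shape  # TypeError: can only concatenate list (not "tuple") to list
reshape_size = orig_shape[:-1] + list(dist_shape)  # [2, 5, 3, 64, 64]

x = torch.zeros(orig_shape)
mean = x.view(*reshape_size)
print(mean.shape)  # torch.Size([2, 5, 3, 64, 64])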
49 changes: 45 additions & 4 deletions rllib/agents/dreamer/dreamer_torch_policy.py
@@ -1,11 +1,18 @@
import logging

import numpy as np
from typing import Dict, Optional

import ray
from ray.rllib.agents.dreamer.utils import FreezeParameters
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.typing import AgentID

torch, nn = try_import_torch()
if torch:
@@ -174,17 +181,17 @@ def dreamer_loss(policy, model, dist_class, train_batch):

def build_dreamer_model(policy, obs_space, action_space, config):

policy.model = ModelCatalog.get_model_v2(
model = ModelCatalog.get_model_v2(
obs_space,
action_space,
1,
config["dreamer_model"],
name="DreamerModel",
framework="torch")

policy.model_variables = policy.model.variables()
policy.model_variables = model.variables()

return policy.model
return model


def action_sampler_fn(policy, model, input_dict, state, explore, timestep):
@@ -197,7 +204,7 @@ def action_sampler_fn(policy, model, input_dict, state, explore, timestep):

# Custom Exploration
if timestep <= policy.config["prefill_timesteps"]:
logp = [0.0]
logp = None
# Random action in space [-1.0, 1.0]
action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0
state = model.get_initial_state()
@@ -236,11 +243,45 @@ def dreamer_optimizer_fn(policy, config):
return (model_opt, actor_opt, critic_opt)


def preprocess_episode(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
"""Batch format should be in the form of (s_t, a_(t-1), r_(t-1))
When t=0, the resetted obs is paired with action and reward of 0.
"""
obs = sample_batch[SampleBatch.OBS]
new_obs = sample_batch[SampleBatch.NEXT_OBS]
action = sample_batch[SampleBatch.ACTIONS]
reward = sample_batch[SampleBatch.REWARDS]
eps_ids = sample_batch[SampleBatch.EPS_ID]

act_shape = action.shape
act_reset = np.array([0.0] * act_shape[-1])[None]
rew_reset = np.array(0.0)[None]
obs_end = np.array(new_obs[act_shape[0] - 1])[None]

batch_obs = np.concatenate([obs, obs_end], axis=0)
batch_action = np.concatenate([act_reset, action], axis=0)
batch_rew = np.concatenate([rew_reset, reward], axis=0)
batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0)

new_batch = {
SampleBatch.OBS: batch_obs,
SampleBatch.REWARDS: batch_rew,
SampleBatch.ACTIONS: batch_action,
SampleBatch.EPS_ID: batch_eps_ids,
}
return SampleBatch(new_batch)


DreamerTorchPolicy = build_policy_class(
name="DreamerTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
action_sampler_fn=action_sampler_fn,
postprocess_fn=preprocess_episode,
loss_fn=dreamer_loss,
stats_fn=dreamer_stats,
make_model=build_dreamer_model,
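A tiny worked example (made-up values) of the (s_t, a_(t-1), r_(t-1)) alignment that the new preprocess_episode() produces: a zero action/reward pair is prepended, the final next-observation is appended, and an episode of T transitions becomes T + 1 aligned rows:

import numpy as np

obs = np.array([[0.1], [0.2], [0.3]])      # s_0, s_1, s_2
new_obs = np.array([[0.2], [0.3], [0.4]])  # s_1, s_2, s_3
action = np.array([[1.0], [2.0], [3.0]])   # a_0, a_1, a_2
reward = np.array([0.5, 0.6, 0.7])         # r_0, r_1, r_2

act_reset = np.zeros((1, action.shape[-1]))
rew_reset = np.array([0.0])
obs_end = new_obs[-1:]                     # s_3

batch_obs = np.concatenate([obs, obs_end], axis=0)      # [s_0, s_1, s_2, s_3]
batch_action = np.concatenate([act_reset, action], 0)   # [0, a_0, a_1, a_2]
batch_rew = np.concatenate([rew_reset, reward], 0)      # [0, r_0, r_1, r_2]
# Row t now holds (s_t, a_(t-1), r_(t-1)); row 0 pairs the reset obs with zeros.
print(batch_obs.shape, batch_action.shape, batch_rew.shape)  # (4, 1) (4, 1) (4,)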
31 changes: 19 additions & 12 deletions rllib/agents/dreamer/tests/test_dreamer.py
@@ -1,11 +1,10 @@
from gym.spaces import Box
import unittest

import ray
from ray import tune
import ray.rllib.agents.dreamer as dreamer
from ray.rllib.examples.env.dm_control_suite import hopper_hop
from ray.rllib.utils.test_utils import check_compute_single_action, \
framework_iterator
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator


class TestDreamer(unittest.TestCase):
@@ -20,19 +19,27 @@ def tearDown(self):
def test_dreamer_compilation(self):
"""Test whether an DreamerTrainer can be built with all frameworks."""
config = dreamer.DEFAULT_CONFIG.copy()
tune.register_env("dm_control_hopper_hop", lambda _: hopper_hop())
config["env_config"] = {
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
"action_space": Box(-1.0, 1.0, (3, ))
}
# Num episode chunks per batch.
config["batch_size"] = 2
# Length (ts) of an episode chunk in a batch.
config["batch_length"] = 20
# Sub-iterations per .train() call.
config["dreamer_train_iters"] = 4

num_iterations = 1

# Test against all frameworks.
for _ in framework_iterator(config, frameworks="torch"):
for env in ["dm_control_hopper_hop"]:
trainer = dreamer.DREAMERTrainer(config=config, env=env)
for i in range(num_iterations):
results = trainer.train()
print(results)
check_compute_single_action(trainer)
trainer.stop()
trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv)
for i in range(num_iterations):
results = trainer.train()
print(results)
# check_compute_single_action(trainer, include_state=True)
trainer.stop()


if __name__ == "__main__":
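For reference, a condensed, standalone sketch of what the re-enabled test does: build a DREAMERTrainer against RLlib's RandomEnv (so no MuJoCo or dm_control install is required) with an image-like observation space, then run one training iteration:

from gym.spaces import Box
import ray
import ray.rllib.agents.dreamer as dreamer
from ray.rllib.examples.env.random_env import RandomEnv

ray.init()
config = dreamer.DEFAULT_CONFIG.copy()
config["env_config"] = {
    "observation_space": Box(-1.0, 1.0, (3, 64, 64)),
    "action_space": Box(-1.0, 1.0, (3, )),
}
config["batch_size"] = 2           # num episode chunks per batch
config["batch_length"] = 20        # length (ts) of an episode chunk
config["dreamer_train_iters"] = 4  # sub-iterations per .train() call

trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv)
print(trainer.train())
trainer.stop()
ray.shutdown()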
4 changes: 4 additions & 0 deletions rllib/agents/dreamer/utils.py
@@ -49,6 +49,10 @@ class TanhBijector(torch.distributions.Transform):
def __init__(self):
super().__init__()

self.bijective = True
self.domain = torch.distributions.constraints.real
self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)

def atanh(self, x):
return 0.5 * torch.log((1 + x) / (1 - x))

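The three attributes added to TanhBijector.__init__() above are what newer torch.distributions versions expect a Transform to declare. A standalone sketch of a tanh transform carrying those declarations (an illustrative re-implementation under assumed math, not a copy of RLlib's class):

import math
import torch
from torch.distributions import Transform, constraints
from torch.nn.functional import softplus

class TanhTransformSketch(Transform):
    def __init__(self):
        super().__init__()
        self.bijective = True
        self.domain = constraints.real
        self.codomain = constraints.interval(-1.0, 1.0)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # atanh, clamped for numerical safety near +/-1.
        y = torch.clamp(y, -0.999999, 0.999999)
        return 0.5 * torch.log((1 + y) / (1 - y))

    def log_abs_det_jacobian(self, x, y):
        # log|d tanh(x)/dx| = log(1 - tanh(x)^2), in a numerically stable form.
        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))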
30 changes: 11 additions & 19 deletions rllib/evaluation/sampler.py
@@ -217,9 +217,8 @@ def __init__(
self.render = render

# Create the rollout generator to use for calls to `get_data()`.
self.rollout_provider = _env_runner(
worker, self.base_env, self.extra_batches.put,
self.rollout_fragment_length, self.horizon, clip_rewards,
self._env_runner = _env_runner(
worker, self.base_env, self.extra_batches.put, self.horizon,
normalize_actions, clip_actions, multiple_episodes_in_batch,
callbacks, self.perf_stats, soft_horizon, no_done_at_end,
observation_fn, self.sample_collector, self.render)
@@ -228,7 +227,7 @@ def __init__(
@override(SamplerInput)
def get_data(self) -> SampleBatchType:
while True:
item = next(self.rollout_provider)
item = next(self._env_runner)
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
@@ -372,11 +371,11 @@ def __init__(
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
worker.policy_map,
clip_rewards,
callbacks,
multiple_episodes_in_batch,
rollout_fragment_length,
self.worker.policy_map,
self.clip_rewards,
self.callbacks,
self.multiple_episodes_in_batch,
self.rollout_fragment_length,
count_steps_by=count_steps_by)

@override(threading.Thread)
@@ -395,9 +394,8 @@ def _run(self):
queue_putter = self.queue.put
extra_batches_putter = (
lambda x: self.extra_batches.put(x, timeout=600.0))
rollout_provider = _env_runner(
self.worker, self.base_env, extra_batches_putter,
self.rollout_fragment_length, self.horizon, self.clip_rewards,
env_runner = _env_runner(
self.worker, self.base_env, extra_batches_putter, self.horizon,
self.normalize_actions, self.clip_actions,
self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
self.soft_horizon, self.no_done_at_end, self.observation_fn,
@@ -406,7 +404,7 @@ def _run(self):
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
# set to some large number. This is an empirical observation.
item = next(rollout_provider)
item = next(env_runner)
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
@@ -450,9 +448,7 @@ def _env_runner(
worker: "RolloutWorker",
base_env: BaseEnv,
extra_batch_callback: Callable[[SampleBatchType], None],
rollout_fragment_length: int,
horizon: int,
clip_rewards: bool,
normalize_actions: bool,
clip_actions: bool,
multiple_episodes_in_batch: bool,
@@ -470,11 +466,7 @@ def _env_runner(
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
rollout_fragment_length (int): Number of episode steps before
`SampleBatch` is yielded. Set to infinity to yield complete
episodes.
horizon (int): Horizon of the episode.
clip_rewards (bool): Whether to clip rewards before postprocessing.
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.