Change SAC library to use the implementation written by pranz24 (facebookresearch#142)

* Added pytorch_sac_pranz24 dependency.

* Changes to pytorch_sac_pranz24 to make it easier to integrate with MBRL-Lib.

* Changed SACAgent to use pranz24's SAC.

* Added test for new add_batch of pranz24's SAC.

* Changed MBPO to use pranz24's SAC.

* Update to MBPO config files.

* Fixed broken MBPO test.

* Added batched= option to SAC.select_action (pranz24's).

* Fixed bug in complete_agent_cfg caused by 'action_space' key.

* Fixed incompatibility between pranz24's GaussianPolicy and hydra.

* Added logger to pranz24.SAC

* Added option to change target entropy in pranz24.SAC

* Added add_batch method to mbrl.util.ReplayBuffer.

* Changed pranz24's SAC to use mbrl.util.ReplayBuffer.

* Added mbrl.util.Logger to pranz24's SAC.

* Added --target_entropy arg to pranz24's SAC.

* Changed MBPO to use mbrl.util.ReplayBuffer both for model and SAC.

* Added option to tell pranz24's SAC.update_parameters() to use ~masks.

* Added mbrl.planning.load_agent implementation for new SAC.

* Updated config files for new SAC.

* Added mbrl.Logger to pytorch_sac_pranz24 and removed tensorboard.

* Added option to use real data for SAC in MBPO, with some small probability.

* [bug-fix] Fixed incorrect termination function for humanoid in make_env.

* Updated MBPO config for inverted pendulum.

* Updated config files for hopper and humanoid.

* Updated configs for MBPO on ant, humanoid and walker.

* Run black and update CHANGELOG.
luisenp authored Feb 3, 2022
1 parent 85e471c commit 4543fc9
Showing 38 changed files with 1,169 additions and 234 deletions.
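
The change list above adds an `add_batch()` method to `mbrl.util.ReplayBuffer` and moves both the model buffer and the SAC buffer onto that class. A minimal usage sketch, following the constructor and argument order that appear in the `mbpo.py` diff below (the shapes and sizes here are illustrative assumptions, not values from the commit):

```python
import numpy as np

import mbrl.util

# Illustrative dimensions only; MBPO derives these from the environment.
obs_shape, act_shape = (17,), (6,)
rng = np.random.default_rng(seed=0)
sac_buffer = mbrl.util.ReplayBuffer(10000, obs_shape, act_shape, rng=rng)

# add_batch() stores a whole batch of transitions at once, in the order
# (obs, action, next_obs, reward, done) used throughout the diff below.
batch = 32
obs = np.zeros((batch,) + obs_shape, dtype=np.float32)
action = np.zeros((batch,) + act_shape, dtype=np.float32)
next_obs = np.zeros((batch,) + obs_shape, dtype=np.float32)
reward = np.zeros(batch, dtype=np.float32)
done = np.zeros(batch, dtype=bool)
sac_buffer.add_batch(obs, action, next_obs, reward, done)

# get_all().astuple() returns the stored transitions in the same order, which is
# how maybe_replace_sac_buffer migrates data into a freshly sized buffer.
obs, action, next_obs, reward, done = sac_buffer.get_all().astuple()
```
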
11 changes: 5 additions & 6 deletions .pre-commit-config.yaml
@@ -1,29 +1,28 @@
repos:
- repo: https://github.com/psf/black
rev: 21.9b0
rev: 22.1.0
hooks:
- id: black
files: 'mbrl'
language_version: python3.7

- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
rev: 3.9.2
hooks:
- id: flake8
files: 'mbrl'
additional_dependencies: [-e, "git+git://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
rev: v0.931
hooks:
- id: mypy
files: 'mbrl'
additional_dependencies: [torch, tokenize-rt==3.2.0]
additional_dependencies: [torch, tokenize-rt==3.2.0, types-PyYAML, types-termcolor]
args: [--no-strict-optional, --ignore-missing-imports]
exclude: setup.py

- repo: https://github.com/pycqa/isort
rev: 5.5.2
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,12 @@
# Changelog

## main (v0.2.0.dev3)
## main (v0.2.0.dev4)
### Main new features
- Added [PlaNet](http://proceedings.mlr.press/v97/hafner19a/hafner19a.pdf) implementation.
- Added support for [PyBullet](https://pybullet.org/wordpress/) environments.
- Changed SAC library used by MBPO
(now based on [Pranjal Tandon's](https://github.com/pranz24/pytorch-soft-actor-critic)).

### Breaking changes
- `Model.reset()` and `Model.sample()` signatures have changed. They no longer receive
`TransitionBatch` objects, and they both return a dictionary of strings to tensors
@@ -16,8 +22,10 @@
is an `omegaconf.DictConfig` specifying the class to use for the activation functions,
thus giving more flexibility.
- Removed unnecessary nesting inside `dynamics_model` Hydra configuration.
- SAC agents prior to v0.2.0 cannot be loaded anymore.

### Other changes
- Added `add_batch()` method to `mbrl.util.ReplayBuffer`.
- Added functions to `mbrl.util.models` to easily create convolutional encoder/decoders
with a desired configuration.
- `mbrl.util.common.rollout_agent_trajectories` now allows rolling out a pixel-based
@@ -29,6 +37,7 @@ truncated normal distribution.
- `mbrl.util.mujoco.make_env` can now create an environment specified via an `omegaconf`
configuration and `hydra.utils.instantiate`, which takes precedence over the old
mechanism if both are present.
- Fixed bug that assigned wrong termination function to `humanoid_truncated_obs` env.

## v0.1.4
- Added MPPI optimizer.
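
Per the breaking-change note in the CHANGELOG diff above, checkpoints produced by the pre-v0.2.0 SAC agent can no longer be loaded; new agents go through the `mbrl.planning.load_agent` implementation mentioned in the change list. A hedged sketch of that path (the exact signature and directory layout are assumptions, not shown in this commit):

```python
import gym

import mbrl.planning

# Assumption: load_agent takes the experiment directory (containing the Hydra config
# and the "sac.pth" checkpoint written by save_checkpoint) plus the environment.
env = gym.make("Pendulum-v1")  # illustrative continuous-control env
agent = mbrl.planning.load_agent("/path/to/mbpo/experiment_dir", env)

obs = env.reset()
action = agent.act(obs)  # SACAgent follows the standard mbrl.planning.Agent interface
```
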
2 changes: 1 addition & 1 deletion mbrl/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "0.2.0.dev3"
__version__ = "0.2.0.dev4"
79 changes: 39 additions & 40 deletions mbrl/algorithms/mbpo.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Optional, Tuple, cast
from typing import Optional, Sequence, cast

import gym
import hydra.utils
@@ -14,12 +14,13 @@
import mbrl.constants
import mbrl.models
import mbrl.planning
import mbrl.third_party.pytorch_sac as pytorch_sac
import mbrl.third_party.pytorch_sac_pranz24 as pytorch_sac_pranz24
import mbrl.types
import mbrl.util
import mbrl.util.common
import mbrl.util.math
from mbrl.planning.sac_wrapper import SACAgent
from mbrl.third_party.pytorch_sac import VideoRecorder

MBPO_LOG_FORMAT = mbrl.constants.EVAL_LOG_FORMAT + [
("epoch", "E", "int"),
@@ -31,7 +32,7 @@ def rollout_model_and_populate_sac_buffer(
model_env: mbrl.models.ModelEnv,
replay_buffer: mbrl.util.ReplayBuffer,
agent: SACAgent,
sac_buffer: pytorch_sac.ReplayBuffer,
sac_buffer: mbrl.util.ReplayBuffer,
sac_samples_action: bool,
rollout_horizon: int,
batch_size: int,
@@ -53,20 +54,19 @@
sac_buffer.add_batch(
obs[~accum_dones],
action[~accum_dones],
pred_rewards[~accum_dones],
pred_next_obs[~accum_dones],
pred_dones[~accum_dones],
pred_dones[~accum_dones],
pred_rewards[~accum_dones, 0],
pred_dones[~accum_dones, 0],
)
obs = pred_next_obs
accum_dones |= pred_dones.squeeze()


def evaluate(
env: gym.Env,
agent: pytorch_sac.Agent,
agent: SACAgent,
num_episodes: int,
video_recorder: pytorch_sac.VideoRecorder,
video_recorder: VideoRecorder,
) -> float:
avg_episode_reward = 0
for episode in range(num_episodes):
@@ -84,27 +84,22 @@


def maybe_replace_sac_buffer(
sac_buffer: Optional[pytorch_sac.ReplayBuffer],
sac_buffer: Optional[mbrl.util.ReplayBuffer],
obs_shape: Sequence[int],
act_shape: Sequence[int],
new_capacity: int,
obs_shape: Tuple[int],
act_shape: Tuple[int],
device: torch.device,
) -> pytorch_sac.ReplayBuffer:
seed: int,
) -> mbrl.util.ReplayBuffer:
if sac_buffer is None or new_capacity != sac_buffer.capacity:
new_buffer = pytorch_sac.ReplayBuffer(
obs_shape, act_shape, new_capacity, device
)
if sac_buffer is None:
rng = np.random.default_rng(seed=seed)
else:
rng = sac_buffer.rng
new_buffer = mbrl.util.ReplayBuffer(new_capacity, obs_shape, act_shape, rng=rng)
if sac_buffer is None:
return new_buffer
n = len(sac_buffer)
new_buffer.add_batch(
sac_buffer.obses[:n],
sac_buffer.actions[:n],
sac_buffer.rewards[:n],
sac_buffer.next_obses[:n],
np.logical_not(sac_buffer.not_dones[:n]),
np.logical_not(sac_buffer.not_dones_no_max[:n]),
)
obs, action, next_obs, reward, done = sac_buffer.get_all().astuple()
new_buffer.add_batch(obs, action, next_obs, reward, done)
return new_buffer
return sac_buffer

@@ -124,7 +119,9 @@ def train(
act_shape = env.action_space.shape

mbrl.planning.complete_agent_cfg(env, cfg.algorithm.agent)
agent = hydra.utils.instantiate(cfg.algorithm.agent)
agent = SACAgent(
cast(pytorch_sac_pranz24.SAC, hydra.utils.instantiate(cfg.algorithm.agent))
)

work_dir = work_dir or os.getcwd()
# enable_back_compatible to use pytorch_sac agent
@@ -136,7 +133,7 @@
dump_frequency=1,
)
save_video = cfg.get("save_video", False)
video_recorder = pytorch_sac.VideoRecorder(work_dir if save_video else None)
video_recorder = VideoRecorder(work_dir if save_video else None)

rng = np.random.default_rng(seed=cfg.seed)
torch_generator = torch.Generator(device=cfg.device)
@@ -196,11 +193,7 @@ def train(
sac_buffer_capacity = rollout_length * rollout_batch_size * trains_per_epoch
sac_buffer_capacity *= cfg.overrides.num_epochs_to_retain_sac_buffer
sac_buffer = maybe_replace_sac_buffer(
sac_buffer,
sac_buffer_capacity,
obs_shape,
act_shape,
torch.device(cfg.device),
sac_buffer, obs_shape, act_shape, sac_buffer_capacity, cfg.seed
)
obs, done = None, False
for steps_epoch in range(cfg.overrides.epoch_length):
Expand Down Expand Up @@ -243,11 +236,20 @@ def train(

# --------------- Agent Training -----------------
for _ in range(cfg.overrides.num_sac_updates_per_step):
use_real_data = rng.random() < cfg.algorithm.real_data_ratio
which_buffer = replay_buffer if use_real_data else sac_buffer
if (env_steps + 1) % cfg.overrides.sac_updates_every_steps != 0 or len(
sac_buffer
) < rollout_batch_size:
which_buffer
) < cfg.overrides.sac_batch_size:
break # only update every once in a while
agent.update(sac_buffer, logger, updates_made)

agent.sac_agent.update_parameters(
which_buffer,
cfg.overrides.sac_batch_size,
updates_made,
logger,
reverse_mask=True,
)
updates_made += 1
if not silent and updates_made % cfg.log_frequency_agent == 0:
logger.dump(updates_made, save=True)
@@ -269,11 +271,8 @@ def train(
if avg_reward > best_eval_reward:
video_recorder.save(f"{epoch}.mp4")
best_eval_reward = avg_reward
torch.save(
agent.critic.state_dict(), os.path.join(work_dir, "critic.pth")
)
torch.save(
agent.actor.state_dict(), os.path.join(work_dir, "actor.pth")
agent.sac_agent.save_checkpoint(
ckpt_path=os.path.join(work_dir, "sac.pth")
)
epoch += 1

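The training-loop changes above replace the old `agent.update()` call with a direct call into pranz24's `update_parameters()`, and switch checkpointing to `save_checkpoint()`. A condensed sketch of one agent-training step, restating the diff (argument names follow `train()`; the comment about masks reflects the "use ~masks" note in the change list):

```python
import os


def sac_update_step(agent, replay_buffer, sac_buffer, cfg, logger, rng,
                    updates_made, work_dir):
    """Condensed sketch of one SAC update in the new MBPO loop (see diff above)."""
    # With probability real_data_ratio, train SAC on real transitions instead of
    # model-generated rollouts.
    use_real_data = rng.random() < cfg.algorithm.real_data_ratio
    which_buffer = replay_buffer if use_real_data else sac_buffer

    # pranz24's SAC builds its targets from masks (1 - done), while
    # mbrl.util.ReplayBuffer stores raw done flags, so reverse_mask=True asks
    # update_parameters to flip them.
    agent.sac_agent.update_parameters(
        which_buffer,
        cfg.overrides.sac_batch_size,
        updates_made,
        logger,
        reverse_mask=True,
    )

    # Checkpointing now goes through pranz24's save_checkpoint rather than saving
    # actor and critic state dicts separately.
    agent.sac_agent.save_checkpoint(ckpt_path=os.path.join(work_dir, "sac.pth"))
    return updates_made + 1
```
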
4 changes: 2 additions & 2 deletions mbrl/env/cartpole_continuous.py
@@ -63,10 +63,10 @@ def step(self, action):
# For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf
temp = (
force + self.polemass_length * theta_dot ** 2 * sintheta
force + self.polemass_length * theta_dot**2 * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
)
xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

2 changes: 1 addition & 1 deletion mbrl/env/pets_cartpole.py
@@ -34,7 +34,7 @@ def step(self, a):
self._get_ee_pos(ob) - np.array([0.0, CartPoleEnv.PENDULUM_LENGTH])
)
)
/ (cost_lscale ** 2)
/ (cost_lscale**2)
)
reward -= 0.01 * np.sum(np.square(a))

6 changes: 3 additions & 3 deletions mbrl/env/reward_fns.py
@@ -19,8 +19,8 @@ def cartpole_pets(act: torch.Tensor, next_obs: torch.Tensor) -> torch.Tensor:
x0 = next_obs[:, :1]
theta = next_obs[:, 1:2]
ee_pos = torch.cat([x0 - 0.6 * theta.sin(), -0.6 * theta.cos()], dim=1)
obs_cost = torch.exp(-torch.sum((ee_pos - goal_pos) ** 2, dim=1) / (0.6 ** 2))
act_cost = -0.01 * torch.sum(act ** 2, dim=1)
obs_cost = torch.exp(-torch.sum((ee_pos - goal_pos) ** 2, dim=1) / (0.6**2))
act_cost = -0.01 * torch.sum(act**2, dim=1)
return (obs_cost + act_cost).view(-1, 1)


@@ -48,6 +48,6 @@ def pusher(act: torch.Tensor, next_obs: torch.Tensor) -> torch.Tensor:
obj_goal_dist = (goal_pos - obj_pos).abs().sum(axis=1)
obs_cost = to_w * tip_obj_dist + og_w * obj_goal_dist

act_cost = 0.1 * (act ** 2).sum(axis=1)
act_cost = 0.1 * (act**2).sum(axis=1)

return -(obs_cost + act_cost).view(-1, 1)
55 changes: 19 additions & 36 deletions mbrl/examples/conf/algorithm/mbpo.yaml
@@ -6,6 +6,7 @@ normalize_double_precision: true
target_is_delta: true
learned_rewards: true
freq_train_model: ${overrides.freq_train_model}
real_data_ratio: 0.0

sac_samples_action: true
initial_exploration_steps: 5000
@@ -16,39 +17,21 @@ num_eval_episodes: 1
# SAC Agent configuration
# --------------------------------------------
agent:
_target_: mbrl.third_party.pytorch_sac.agent.sac.SACAgent
obs_dim: ??? # to be specified later
action_dim: ??? # to be specified later
action_range: ??? # to be specified later
device: ${device}
critic_cfg: ${algorithm.double_q_critic}
actor_cfg: ${algorithm.diag_gaussian_actor}
discount: 0.99
init_temperature: 0.1
alpha_lr: ${overrides.sac_alpha_lr}
alpha_betas: [0.9, 0.999]
actor_lr: ${overrides.sac_actor_lr}
actor_betas: [0.9, 0.999]
actor_update_frequency: ${overrides.sac_actor_update_frequency}
critic_lr: ${overrides.sac_critic_lr}
critic_betas: [0.9, 0.999]
critic_tau: 0.005
critic_target_update_frequency: ${overrides.sac_critic_target_update_frequency}
batch_size: 256
learnable_temperature: true
target_entropy: ${overrides.sac_target_entropy}

double_q_critic:
_target_: mbrl.third_party.pytorch_sac.agent.critic.DoubleQCritic
obs_dim: ${algorithm.agent.obs_dim}
action_dim: ${algorithm.agent.action_dim}
hidden_dim: 1024
hidden_depth: ${overrides.sac_hidden_depth}

diag_gaussian_actor:
_target_: mbrl.third_party.pytorch_sac.agent.actor.DiagGaussianActor
obs_dim: ${algorithm.agent.obs_dim}
action_dim: ${algorithm.agent.action_dim}
hidden_depth: ${overrides.sac_hidden_depth}
hidden_dim: 1024
log_std_bounds: [-5, 2]
_target_: mbrl.third_party.pytorch_sac_pranz24.sac.SAC
num_inputs: ???
action_space:
_target_: gym.env.Box
low: ???
high: ???
shape: ???
args:
gamma: ${overrides.sac_gamma}
tau: ${overrides.sac_tau}
alpha: ${overrides.sac_alpha}
policy: ${overrides.sac_policy}
target_update_interval: ${overrides.sac_target_update_interval}
automatic_entropy_tuning: ${overrides.sac_automatic_entropy_tuning}
target_entropy: ${overrides.sac_target_entropy}
hidden_size: ${overrides.sac_hidden_size}
device: ${device}
lr: ${overrides.sac_lr}
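
Because the agent `_target_` now points at `pytorch_sac_pranz24.sac.SAC` rather than a ready-made `Agent` subclass, the training code instantiates it through Hydra and wraps it in `SACAgent`, as the `mbpo.py` diff shows. A short sketch of that flow (the note about what `complete_agent_cfg` fills in is an inference from the `???` placeholders above):

```python
from typing import cast

import hydra.utils

import mbrl.planning
import mbrl.third_party.pytorch_sac_pranz24 as pytorch_sac_pranz24
from mbrl.planning.sac_wrapper import SACAgent


def build_sac_agent(env, cfg) -> SACAgent:
    """Sketch of agent construction as done in mbpo.train() (see diff above)."""
    # complete_agent_cfg is expected to fill num_inputs and the action_space node
    # from `env`, which is why the YAML above leaves them as ???.
    mbrl.planning.complete_agent_cfg(env, cfg.algorithm.agent)
    return SACAgent(
        cast(pytorch_sac_pranz24.SAC, hydra.utils.instantiate(cfg.algorithm.agent))
    )
```
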
20 changes: 12 additions & 8 deletions mbrl/examples/conf/overrides/mbpo_ant.yaml
@@ -1,12 +1,13 @@
# @package _group_
env: "ant_truncated_obs"
# term_fn is set automatically by mbrl.util.env.EnvHandler.make_env

num_steps: 300000
epoch_length: 1000
num_elites: 5
patience: 10
model_lr: 0.0003
model_wd: 0.00002
model_wd: 5e-5
model_batch_size: 256
validation_ratio: 0.2
freq_train_model: 250
@@ -16,10 +17,13 @@ num_sac_updates_per_step: 20
sac_updates_every_steps: 1
num_epochs_to_retain_sac_buffer: 1

sac_critic_lr: 0.0003
sac_alpha_lr: 0.003
sac_actor_lr: 0.0003
sac_actor_update_frequency: 4
sac_critic_target_update_frequency: 1
sac_target_entropy: -3
sac_hidden_depth: 2
sac_gamma: 0.99
sac_tau: 0.005
sac_alpha: 0.2
sac_policy: "Gaussian"
sac_target_update_interval: 4
sac_automatic_entropy_tuning: false
sac_target_entropy: -1 # ignored, since entropy tuning is false
sac_hidden_size: 1024
sac_lr: 0.0001
sac_batch_size: 256
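
These overrides are plain OmegaConf values, so they can be inspected or tweaked programmatically before a run; note that `sac_target_entropy` only takes effect when `sac_automatic_entropy_tuning` is true, per the comment above. A small sketch of assumed typical usage (not code from this commit):

```python
from omegaconf import OmegaConf

# Load the overrides file shown above and enable entropy tuning so that the
# target entropy value is actually honored.
cfg = OmegaConf.load("mbrl/examples/conf/overrides/mbpo_ant.yaml")
cfg.sac_automatic_entropy_tuning = True
cfg.sac_target_entropy = -4.0  # illustrative value, not from the commit
print(OmegaConf.to_yaml(cfg))
```
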
