Skip to content

Commit

Permalink
Add TQC support (#46)
Browse files Browse the repository at this point in the history
* Add TQC from sb3-contrib

* Update plot script

* Tuned hyperparams for Hopper

* Update plot

* Update hyperparams

* Update plot script: allow merging files

* Make pytype happy

* Update humanoid params

* Revert Humanoids params

* Fix deps

* Fixes

* Add support for HER + TQC
  • Loading branch information
araffin authored Oct 22, 2020
1 parent 26dfece commit c7763b7
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features
- Added support for `HER`
- Added low-pass filter wrappers in `utils/wrappers.py`
- Added `TQC` support, implementation from sb3-contrib

### Bug fixes
- Fixed `TimeFeatureWrapper` inferring max timesteps
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pytest:

# Type check
type:
pytype ${LINT_PATHS}
pytype -j auto ${LINT_PATHS}

lint:
# stop the build if there are Python syntax errors or undefined names
Expand Down
10 changes: 5 additions & 5 deletions enjoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def main(): # noqa: C901

if args.exp_id == 0:
args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
print("Loading latest experiment, id={}".format(args.exp_id))
print(f"Loading latest experiment, id={args.exp_id}")

# Sanity checks
if args.exp_id > 0:
log_path = os.path.join(folder, algo, "{}_{}".format(env_id, args.exp_id))
log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}")
else:
log_path = os.path.join(folder, algo)

Expand All @@ -93,7 +93,7 @@ def main(): # noqa: C901
if not found:
raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")

if algo in ["dqn", "ddpg", "sac", "td3"]:
if algo in ["dqn", "ddpg", "sac", "td3", "tqc"]:
args.n_envs = 1

set_random_seed(args.seed)
Expand Down Expand Up @@ -134,7 +134,7 @@ def main(): # noqa: C901
)

kwargs = dict(seed=args.seed)
if algo in ["dqn", "ddpg", "sac", "her", "td3"]:
if algo in ["dqn", "ddpg", "sac", "her", "td3", "tqc"]:
# Dummy buffer size as we don't need memory to enjoy the trained agent
kwargs.update(dict(buffer_size=1))

Expand All @@ -143,7 +143,7 @@ def main(): # noqa: C901
obs = env.reset()

# Force deterministic for DQN, DDPG, SAC, TD3, TQC and HER (which is only a wrapper around one of these algorithms)
deterministic = args.deterministic or algo in ["dqn", "ddpg", "sac", "her", "td3"] and not args.stochastic
deterministic = args.deterministic or algo in ["dqn", "ddpg", "sac", "her", "td3", "tqc"] and not args.stochastic

state = None
episode_reward = 0.0
Expand Down
202 changes: 202 additions & 0 deletions hyperparams/tqc.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Tuned
MountainCarContinuous-v0:
n_timesteps: !!float 50000
policy: 'MlpPolicy'
learning_rate: !!float 3e-4
buffer_size: 50000
batch_size: 512
ent_coef: 0.1
train_freq: 32
gradient_steps: 32
gamma: 0.9999
tau: 0.01
learning_starts: 0
use_sde: True
policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])"

Pendulum-v0:
n_timesteps: 20000
policy: 'MlpPolicy'
learning_rate: !!float 1e-3
use_sde: True
n_episodes_rollout: 1
gradient_steps: -1
train_freq: -1
policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])"

LunarLanderContinuous-v2:
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
batch_size: 256
learning_starts: 1000

BipedalWalker-v3:
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Almost tuned
# History wrapper of size 2 for better performance
BipedalWalkerHardcore-v3:
n_timesteps: !!float 2e6
policy: 'MlpPolicy'
learning_rate: lin_7.3e-4
buffer_size: 1000000
batch_size: 256
ent_coef: 'auto'
gamma: 0.99
tau: 0.01
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300], use_expln=True)"

# === Bullet envs ===

# Tuned
HalfCheetahBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Tuned
AntBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Tuned
HopperBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: lin_7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
top_quantiles_to_drop_per_net: 5
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Tuned
Walker2DBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: lin_7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"


ReacherBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"


# Almost tuned
HumanoidBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e7
policy: 'MlpPolicy'
learning_rate: lin_7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
top_quantiles_to_drop_per_net: 5
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

InvertedDoublePendulumBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

InvertedPendulumSwingupBulletEnv-v0:
env_wrapper: utils.wrappers.TimeFeatureWrapper
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
buffer_size: 300000
batch_size: 256
ent_coef: 'auto'
gamma: 0.98
tau: 0.02
train_freq: 64
gradient_steps: 64
learning_starts: 10000
use_sde: True
policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
stable-baselines3[extra,tests,docs]>=0.10.0a0
stable-baselines3[extra,tests,docs]>=0.10.0a1
box2d-py==2.3.5
pybullet
gym-minigrid
Expand All @@ -7,3 +7,4 @@ optuna
pytablewriter
seaborn
pyyaml>=5.1
sb3-contrib>=0.10.0a1
27 changes: 25 additions & 2 deletions scripts/plot_from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
parser = argparse.ArgumentParser("Gather results, plot them and create table")
parser.add_argument("-i", "--input", help="Input filename (numpy archive)", type=str)
parser.add_argument("-skip", "--skip-envs", help="Environments to skip", nargs="+", default=[], type=str)
parser.add_argument("--keep-envs", help="Envs to keep", nargs="+", default=[], type=str)
parser.add_argument("--skip-keys", help="Keys to skip", nargs="+", default=[], type=str)
parser.add_argument("--keep-keys", help="Keys to keep", nargs="+", default=[], type=str)
parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million")
parser.add_argument("--skip-timesteps", action="store_true", default=False, help="Do not display learning curves")
parser.add_argument("-o", "--output", help="Output filename (image)", type=str)
Expand All @@ -40,6 +43,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+")
parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False)
parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)
parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str)

args = parser.parse_args()

Expand Down Expand Up @@ -69,8 +73,26 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5

del results["results_table"]

keys = [key for key in results[list(results.keys())[0]].keys()]
for filename in args.merge:
# Merge other files
with open(filename, "rb") as file_handler:
results_2 = pickle.load(file_handler)
del results_2["results_table"]
for key in results.keys():
if key in results_2:
for new_key in results_2[key].keys():
results[key][new_key] = results_2[key][new_key]


keys = [key for key in results[list(results.keys())[0]].keys() if key not in args.skip_keys]
print(f"keys: {keys}")
if len(args.keep_keys) > 0:
keys = [key for key in keys if key in args.keep_keys]
envs = [env for env in results.keys() if env not in args.skip_envs]

if len(args.keep_envs) > 0:
envs = [env for env in envs if env in args.keep_envs]

labels = {key: key for key in keys}
if args.labels is not None:
for key, label in zip(keys, args.labels):
Expand Down Expand Up @@ -129,9 +151,10 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
# plt.title('Influence of the time feature', fontsize=args.fontsize)
# plt.title('Influence of the network architecture', fontsize=args.fontsize)
# plt.title('Influence of the exploration variance $log \sigma$', fontsize=args.fontsize)
plt.title("Influence of the sampling frequency", fontsize=args.fontsize)
# plt.title("Influence of the sampling frequency", fontsize=args.fontsize)
# plt.title('Parallel vs No Parallel Sampling', fontsize=args.fontsize)
# plt.title('Influence of the exploration function input', fontsize=args.fontsize)
plt.title("PyBullet envs", fontsize=args.fontsize)
plt.xticks(fontsize=13)
plt.xlabel("Environment", fontsize=args.fontsize)
plt.ylabel("Score", fontsize=args.fontsize)
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
# HER is only a wrapper around an algo
if args.algo == "her":
algo_ = saved_hyperparams["model_class"]
assert algo_ in {"sac", "ddpg", "dqn", "td3"}, f"{algo_} is not compatible with HER"
assert algo_ in {"sac", "ddpg", "dqn", "td3", "tqc"}, f"{algo_} is not compatible with HER"
# Retrieve the model class
hyperparams["model_class"] = ALGOS[saved_hyperparams["model_class"]]

Expand Down
22 changes: 17 additions & 5 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,31 @@

import gym
import yaml

# from stable_baselines3.common import logger
from stable_baselines3 import A2C, DDPG, DQN, HER, PPO, SAC, TD3
from stable_baselines3.common.monitor import Monitor

# from stable_baselines3.common.cmd_util import make_atari_env
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize

try:
from sb3_contrib import TQC # pytype: disable=import-error
except ImportError:
TQC = None

# For custom activation fn
from torch import nn as nn # noqa: F401 pylint: disable=unused-import

ALGOS = {"a2c": A2C, "ddpg": DDPG, "dqn": DQN, "her": HER, "ppo": PPO, "sac": SAC, "td3": TD3}
ALGOS = {
"a2c": A2C,
"ddpg": DDPG,
"dqn": DQN,
"ppo": PPO,
"her": HER,
"sac": SAC,
"td3": TD3,
}

if TQC is not None:
ALGOS["tqc"] = TQC


def flatten_dict_observations(env):
Expand Down
4 changes: 1 addition & 3 deletions utils/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gym
import numpy as np
from matplotlib import pyplot as plt
from scipy.signal import iirfilter, sosfilt, zpk2sos


class DoneOnSuccessWrapper(gym.Wrapper):
Expand Down Expand Up @@ -170,9 +171,6 @@ def step(self, action):


# from https://docs.obspy.org
from scipy.signal import iirfilter, sosfilt, zpk2sos


def lowpass(data, freq, df, corners=4, zerophase=False):
"""
Butterworth-Lowpass Filter.
Expand Down

0 comments on commit c7763b7

Please sign in to comment.