Train/Eval Mode Support #39

Merged · 4 commits · Sep 8, 2021
Changes from all commits
docs/misc/changelog.rst (6 additions, 2 deletions)
@@ -4,14 +4,18 @@ Changelog
==========


Release 1.2.0a0 (WIP)
Release 1.2.0 (2021-09-08)
-------------------------------

**Train/Eval mode support**

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.2.0

Bug Fixes:
^^^^^^^^^^
- QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)

Deprecations:
^^^^^^^^^^^^^
@@ -152,4 +156,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------

@ku2482 @guyk1971 @minhlong94
@ku2482 @guyk1971 @minhlong94 @ayeright
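
For context on the fix above: PyTorch layers such as `nn.BatchNorm1d` and `nn.Dropout` behave differently depending on whether a module is in train or eval mode, which is exactly what the new `set_training_mode` calls control. A minimal, standalone PyTorch sketch (independent of this PR) illustrating the difference:

```python
import torch as th
import torch.nn as nn

# Tiny network containing the two layer types whose behaviour depends on the mode.
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(0.5))
x = th.randn(32, 4)

net.train()  # train mode: batch norm updates running stats, dropout is active
y1, y2 = net(x), net(x)
print(th.allclose(y1, y2))  # almost surely False because of dropout

net.eval()  # eval mode: running stats frozen, dropout disabled
with th.no_grad():
    y1, y2 = net(x), net(x)
print(th.allclose(y1, y2))  # True: predictions are deterministic
```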
sb3_contrib/qrdqn/policies.py (10 additions, 0 deletions)
@@ -167,6 +167,7 @@ def _build(self, lr_schedule: Schedule) -> None:
self.quantile_net = self.make_quantile_net()
self.quantile_net_target = self.make_quantile_net()
self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())
self.quantile_net_target.set_training_mode(False)

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
@@ -199,6 +200,15 @@ def _get_constructor_parameters(self) -> Dict[str, Any]:
)
return data

def set_training_mode(self, mode: bool) -> None:
"""
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
:param mode: if true, set to training mode, else set to evaluation mode
"""
self.quantile_net.set_training_mode(mode)
self.training = mode


MlpPolicy = QRDQNPolicy

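Note that the new `QRDQNPolicy.set_training_mode` deliberately leaves `quantile_net_target` alone: the target network is put into eval mode once in `_build` and stays there. Since `set_training_mode` on a sub-network ultimately maps onto PyTorch's `nn.Module.train()` / `eval()`, the behaviour is roughly equivalent to the following sketch (illustrative only, not the library implementation):

```python
def set_training_mode_sketch(policy, mode: bool) -> None:
    # Only the online quantile network follows the requested mode ...
    policy.quantile_net.train(mode)  # train(False) is the same as eval()
    policy.training = mode
    # ... while the target network stays in eval mode permanently.
    assert not policy.quantile_net_target.training
```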
sb3_contrib/qrdqn/qrdqn.py (2 additions, 0 deletions)
@@ -155,6 +155,8 @@ def _on_step(self) -> None:
self.logger.record("rollout/exploration rate", self.exploration_rate)

def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)

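The counterpart to this `train()` change lives in Stable-Baselines3 itself: from version 1.2.0 on, rollout collection and `predict()` switch the policy back to eval mode, so batch norm statistics are only updated by gradient steps. A rough usage sketch (hypothetical script, assuming the standard Gym CartPole environment is available):

```python
from sb3_contrib import QRDQN

model = QRDQN("MlpPolicy", "CartPole-v1", learning_starts=100, verbose=0)

# During learn(), each gradient update first calls
# model.policy.set_training_mode(True) (the change above), while data
# collection and predict() put the policy back into eval mode.
model.learn(total_timesteps=500)

# When driving the policy manually, the mode can also be set explicitly:
model.policy.set_training_mode(False)
obs = model.get_env().reset()
action, _ = model.predict(obs, deterministic=True)
```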
sb3_contrib/tqc/policies.py (13 additions, 0 deletions)
@@ -376,6 +376,9 @@ def _build(self, lr_schedule: Schedule) -> None:
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())

# Target networks should always be in eval mode
self.critic_target.set_training_mode(False)

self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

def _get_constructor_parameters(self) -> Dict[str, Any]:
@@ -423,6 +426,16 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)

def set_training_mode(self, mode: bool) -> None:
"""
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
:param mode: if true, set to training mode, else set to evaluation mode
"""
self.actor.set_training_mode(mode)
self.critic.set_training_mode(mode)
self.training = mode


MlpPolicy = TQCPolicy

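Keeping `critic_target` permanently in eval mode makes sense because the target critic is never trained directly: its weights only track the online critic through a soft (Polyak) update, and its batch norm running statistics should not drift from forward passes during training (the new test below even sets `tau=0` so the target must not change at all). A minimal sketch of such a soft update, assuming only parameters are blended (Stable-Baselines3 ships its own `polyak_update` helper; this standalone version is purely illustrative):

```python
import torch as th
import torch.nn as nn


def polyak_update_sketch(online: nn.Module, target: nn.Module, tau: float) -> None:
    """Soft update: target <- tau * online + (1 - tau) * target (parameters only)."""
    with th.no_grad():
        for online_p, target_p in zip(online.parameters(), target.parameters()):
            target_p.data.mul_(1.0 - tau)
            target_p.data.add_(tau * online_p.data)
```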
sb3_contrib/tqc/tqc.py (2 additions, 0 deletions)
@@ -175,6 +175,8 @@ def _create_aliases(self) -> None:
self.critic_target = self.policy.critic_target

def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
sb3_contrib/version.txt (1 addition, 1 deletion)
@@ -1 +1 @@
1.2.0a0
1.2.0
setup.py (1 addition, 1 deletion)
@@ -62,7 +62,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.1.0",
"stable_baselines3>=1.2.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
tests/test_train_eval_mode.py (225 additions, 0 deletions)
@@ -0,0 +1,225 @@
import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from sb3_contrib import QRDQN, TQC


class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
"""
Feature extractor that flattens the input and applies batch normalization and dropout.
Used in these tests so that the policy behaves differently in train and eval mode.
:param observation_space: the observation space of the environment
"""

def __init__(self, observation_space: gym.Space):
super(FlattenBatchNormDropoutExtractor, self).__init__(
observation_space,
get_flattened_obs_dim(observation_space),
)
self.flatten = nn.Flatten()
self.batch_norm = nn.BatchNorm1d(self._features_dim)
self.dropout = nn.Dropout(0.5)

def forward(self, observations: th.Tensor) -> th.Tensor:
result = self.flatten(observations)
result = self.batch_norm(result)
result = self.dropout(result)
return result


def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor):
"""
Clone the bias and running mean from the given batch norm layer.
:param batch_norm:
:return: the bias and running mean
"""
return batch_norm.bias.clone(), batch_norm.running_mean.clone()


def clone_qrdqn_batch_norm_stats(model: QRDQN) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor):
"""
Clone the bias and running mean from the quantile network and quantile-target network.
:param model:
:return: the bias and running mean from the quantile network and quantile-target network
"""
quantile_net_batch_norm = model.policy.quantile_net.features_extractor.batch_norm
quantile_net_bias, quantile_net_running_mean = clone_batch_norm_stats(quantile_net_batch_norm)

quantile_net_target_batch_norm = model.policy.quantile_net_target.features_extractor.batch_norm
quantile_net_target_bias, quantile_net_target_running_mean = clone_batch_norm_stats(quantile_net_target_batch_norm)

return quantile_net_bias, quantile_net_running_mean, quantile_net_target_bias, quantile_net_target_running_mean


def clone_tqc_batch_norm_stats(
model: TQC,
) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor):
"""
Clone the bias and running mean from the actor and critic networks and critic-target networks.
:param model:
:return: the bias and running mean from the actor and critic networks and critic-target networks
"""
actor_batch_norm = model.actor.features_extractor.batch_norm
actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

critic_batch_norm = model.critic.features_extractor.batch_norm
critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)


CLONE_HELPERS = {
QRDQN: clone_qrdqn_batch_norm_stats,
TQC: clone_tqc_batch_norm_stats,
}


def test_qrdqn_train_with_batch_norm():
model = QRDQN(
"MlpPolicy",
"CartPole-v1",
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0,
seed=1,
tau=0, # do not clone the target
)

(
quantile_net_bias_before,
quantile_net_running_mean_before,
quantile_net_target_bias_before,
quantile_net_target_running_mean_before,
) = clone_qrdqn_batch_norm_stats(model)

model.learn(total_timesteps=200)

(
quantile_net_bias_after,
quantile_net_running_mean_after,
quantile_net_target_bias_after,
quantile_net_target_running_mean_after,
) = clone_qrdqn_batch_norm_stats(model)

assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all()
assert ~th.isclose(quantile_net_running_mean_before, quantile_net_running_mean_after).all()

assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all()
assert th.isclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after).all()


def test_tqc_train_with_batch_norm():
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0,
tau=0, # do not copy the target
seed=1,
)

(
actor_bias_before,
actor_running_mean_before,
critic_bias_before,
critic_running_mean_before,
critic_target_bias_before,
critic_target_running_mean_before,
) = clone_tqc_batch_norm_stats(model)

model.learn(total_timesteps=200)

(
actor_bias_after,
actor_running_mean_after,
critic_bias_after,
critic_running_mean_after,
critic_target_bias_after,
critic_target_running_mean_after,
) = clone_tqc_batch_norm_stats(model)

assert ~th.isclose(actor_bias_before, actor_bias_after).all()
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()

assert ~th.isclose(critic_bias_before, critic_bias_after).all()
assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()

assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()


@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
if model_class in [QRDQN]:
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v0"

clone_helper = CLONE_HELPERS[model_class]

learning_starts = 10
model = model_class(
"MlpPolicy",
env_id,
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=learning_starts,
seed=1,
gradient_steps=0,
train_freq=1,
)

batch_norm_stats_before = clone_helper(model)

model.learn(total_timesteps=100)

batch_norm_stats_after = clone_helper(model)

# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()


@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
if env_id == "CartPole-v1":
if model_class in [TQC]:
return
elif model_class in [QRDQN]:
return

model_kwargs = dict(seed=1)
clone_helper = CLONE_HELPERS[model_class]

if model_class in [QRDQN, TQC]:
model_kwargs["learning_starts"] = 0
else:
model_kwargs["n_steps"] = 64

policy_kwargs = dict(
features_extractor_class=FlattenBatchNormDropoutExtractor,
net_arch=[16, 16],
)
model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

batch_norm_stats_before = clone_helper(model)

env = model.get_env()
observation = env.reset()
first_prediction, _ = model.predict(observation, deterministic=True)
for _ in range(5):
prediction, _ = model.predict(observation, deterministic=True)
np.testing.assert_allclose(first_prediction, prediction)

batch_norm_stats_after = clone_helper(model)

# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()