Train/Eval Mode Support (Stable-Baselines-Team#39)
* switch models between train and eval mode

* update changelog

* update release in change log

* Update dependency

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
ayeright and araffin authored Sep 8, 2021
1 parent 36eca8e commit b2e7126
Showing 8 changed files with 260 additions and 4 deletions.
8 changes: 6 additions & 2 deletions docs/misc/changelog.rst
@@ -4,14 +4,18 @@ Changelog
==========


Release 1.2.0a0 (WIP)
Release 1.2.0 (2021-09-08)
-------------------------------

**Train/Eval mode support**

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.2.0

Bug Fixes:
^^^^^^^^^^
- QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)

Deprecations:
^^^^^^^^^^^^^
@@ -152,4 +156,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------

@ku2482 @guyk1971 @minhlong94
@ku2482 @guyk1971 @minhlong94 @ayeright
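
For context, a minimal sketch of what this release adds, as seen from the public API (not part of the diff; it assumes only QRDQN and the standard CartPole-v1 environment). The policies now expose set_training_mode, which the algorithms call internally: train mode while gradients are computed, eval mode during rollouts and prediction, so that layers such as BatchNorm and Dropout behave correctly.

    from sb3_contrib import QRDQN

    model = QRDQN("MlpPolicy", "CartPole-v1")

    # Called internally by train() before gradient updates
    model.policy.set_training_mode(True)
    assert model.policy.training

    # Called internally before rollout collection and predict()
    model.policy.set_training_mode(False)
    assert not model.policy.training
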
10 changes: 10 additions & 0 deletions sb3_contrib/qrdqn/policies.py
@@ -167,6 +167,7 @@ def _build(self, lr_schedule: Schedule) -> None:
self.quantile_net = self.make_quantile_net()
self.quantile_net_target = self.make_quantile_net()
self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())
self.quantile_net_target.set_training_mode(False)

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
@@ -199,6 +200,15 @@ def _get_constructor_parameters(self) -> Dict[str, Any]:
)
return data

def set_training_mode(self, mode: bool) -> None:
"""
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
:param mode: if true, set to training mode, else set to evaluation mode
"""
self.quantile_net.set_training_mode(mode)
self.training = mode


MlpPolicy = QRDQNPolicy

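
A hedged check of the invariant the QRDQNPolicy changes above establish (a sketch using only the public API, not code from this commit): set_training_mode toggles the online quantile network, while the target network created in _build() stays in eval mode.

    from sb3_contrib import QRDQN

    model = QRDQN("MlpPolicy", "CartPole-v1")
    model.policy.set_training_mode(True)
    # The online network follows the flag ...
    assert model.policy.quantile_net.training
    # ... but the target network remains frozen in eval mode.
    assert not model.policy.quantile_net_target.training
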
2 changes: 2 additions & 0 deletions sb3_contrib/qrdqn/qrdqn.py
@@ -155,6 +155,8 @@ def _on_step(self) -> None:
self.logger.record("rollout/exploration rate", self.exploration_rate)

def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)

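
The switch to train mode above is one half of the picture: with Stable-Baselines3 >= 1.2.0 (the new minimum pinned in setup.py below), rollout collection switches the policy back to eval mode before stepping the environment, and TQC's train() further down gets the same change. A simplified illustration of the resulting alternation (not the actual SB3 control flow, only the idea):

    def learn_iteration(model):
        # Rollout collection: interact with the environment in eval mode
        model.policy.set_training_mode(False)
        # ... step the env and store transitions in the replay buffer ...

        # train(): run gradient updates in train mode
        model.policy.set_training_mode(True)
        # ... sample the buffer, compute the loss, take an optimizer step ...
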
13 changes: 13 additions & 0 deletions sb3_contrib/tqc/policies.py
@@ -376,6 +376,9 @@ def _build(self, lr_schedule: Schedule) -> None:
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())

# Target networks should always be in eval mode
self.critic_target.set_training_mode(False)

self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

def _get_constructor_parameters(self) -> Dict[str, Any]:
@@ -423,6 +426,16 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)

def set_training_mode(self, mode: bool) -> None:
"""
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
:param mode: if true, set to training mode, else set to evaluation mode
"""
self.actor.set_training_mode(mode)
self.critic.set_training_mode(mode)
self.training = mode


MlpPolicy = TQCPolicy

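
Analogous to QR-DQN, a hedged check of what the TQCPolicy changes above imply (a sketch only, using Pendulum-v0 as in the new tests; not code from this commit): the actor and critic follow the mode flag, while the critic target built in _build() is kept in eval mode.

    from sb3_contrib import TQC

    model = TQC("MlpPolicy", "Pendulum-v0")
    model.policy.set_training_mode(True)
    assert model.actor.training               # follows the flag
    assert model.critic.training              # follows the flag
    assert not model.critic_target.training   # always eval, set in _build()
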
2 changes: 2 additions & 0 deletions sb3_contrib/tqc/tqc.py
@@ -175,6 +175,8 @@ def _create_aliases(self) -> None:
self.critic_target = self.policy.critic_target

def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
@@ -1 +1 @@
1.2.0a0
1.2.0
2 changes: 1 addition & 1 deletion setup.py
@@ -62,7 +62,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.1.0",
"stable_baselines3>=1.2.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
225 changes: 225 additions & 0 deletions tests/test_train_eval_mode.py
@@ -0,0 +1,225 @@
import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from sb3_contrib import QRDQN, TQC


class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
"""
Feature extractor that flattens the input and applies batch normalization and dropout.
Used to check that train and eval modes are handled correctly.
:param observation_space:
"""

def __init__(self, observation_space: gym.Space):
super(FlattenBatchNormDropoutExtractor, self).__init__(
observation_space,
get_flattened_obs_dim(observation_space),
)
self.flatten = nn.Flatten()
self.batch_norm = nn.BatchNorm1d(self._features_dim)
self.dropout = nn.Dropout(0.5)

def forward(self, observations: th.Tensor) -> th.Tensor:
result = self.flatten(observations)
result = self.batch_norm(result)
result = self.dropout(result)
return result


def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor):
"""
Clone the bias and running mean from the given batch norm layer.
:param batch_norm:
:return: the bias and running mean
"""
return batch_norm.bias.clone(), batch_norm.running_mean.clone()


def clone_qrdqn_batch_norm_stats(model: QRDQN) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor):
"""
Clone the bias and running mean from the quantile network and quantile-target network.
:param model:
:return: the bias and running mean from the quantile network and quantile-target network
"""
quantile_net_batch_norm = model.policy.quantile_net.features_extractor.batch_norm
quantile_net_bias, quantile_net_running_mean = clone_batch_norm_stats(quantile_net_batch_norm)

quantile_net_target_batch_norm = model.policy.quantile_net_target.features_extractor.batch_norm
quantile_net_target_bias, quantile_net_target_running_mean = clone_batch_norm_stats(quantile_net_target_batch_norm)

return quantile_net_bias, quantile_net_running_mean, quantile_net_target_bias, quantile_net_target_running_mean


def clone_tqc_batch_norm_stats(
model: TQC,
) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor):
"""
Clone the bias and running mean from the actor, critic and critic-target networks.
:param model:
:return: the bias and running mean from the actor, critic and critic-target networks
"""
actor_batch_norm = model.actor.features_extractor.batch_norm
actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

critic_batch_norm = model.critic.features_extractor.batch_norm
critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)


CLONE_HELPERS = {
QRDQN: clone_qrdqn_batch_norm_stats,
TQC: clone_tqc_batch_norm_stats,
}


def test_qrdqn_train_with_batch_norm():
model = QRDQN(
"MlpPolicy",
"CartPole-v1",
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0,
seed=1,
tau=0, # do not clone the target
)

(
quantile_net_bias_before,
quantile_net_running_mean_before,
quantile_net_target_bias_before,
quantile_net_target_running_mean_before,
) = clone_qrdqn_batch_norm_stats(model)

model.learn(total_timesteps=200)

(
quantile_net_bias_after,
quantile_net_running_mean_after,
quantile_net_target_bias_after,
quantile_net_target_running_mean_after,
) = clone_qrdqn_batch_norm_stats(model)
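
# The online quantile network was trained in train mode, so its batch-norm
# stats must have moved (~(...).all() means "not all elements are close");
# the target network, kept in eval mode, must be unchanged.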

assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all()
assert ~th.isclose(quantile_net_running_mean_before, quantile_net_running_mean_after).all()

assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all()
assert th.isclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after).all()


def test_tqc_train_with_batch_norm():
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0,
tau=0, # do not copy the target
seed=1,
)

(
actor_bias_before,
actor_running_mean_before,
critic_bias_before,
critic_running_mean_before,
critic_target_bias_before,
critic_target_running_mean_before,
) = clone_tqc_batch_norm_stats(model)

model.learn(total_timesteps=200)

(
actor_bias_after,
actor_running_mean_after,
critic_bias_after,
critic_running_mean_after,
critic_target_bias_after,
critic_target_running_mean_after,
) = clone_tqc_batch_norm_stats(model)
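
# The actor and critic were trained in train mode, so their batch-norm stats
# must have moved; the critic target, kept in eval mode, must be unchanged.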

assert ~th.isclose(actor_bias_before, actor_bias_after).all()
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()

assert ~th.isclose(critic_bias_before, critic_bias_after).all()
assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()

assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()


@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
if model_class in [QRDQN]:
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v0"

clone_helper = CLONE_HELPERS[model_class]

learning_starts = 10
model = model_class(
"MlpPolicy",
env_id,
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=learning_starts,
seed=1,
gradient_steps=0,
train_freq=1,
)

batch_norm_stats_before = clone_helper(model)

model.learn(total_timesteps=100)

batch_norm_stats_after = clone_helper(model)

# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()


@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
if env_id == "CartPole-v1":
if model_class in [TQC]:
return
elif model_class in [QRDQN]:
return

model_kwargs = dict(seed=1)
clone_helper = CLONE_HELPERS[model_class]

if model_class in [QRDQN, TQC]:
model_kwargs["learning_starts"] = 0
else:
model_kwargs["n_steps"] = 64

policy_kwargs = dict(
features_extractor_class=FlattenBatchNormDropoutExtractor,
net_arch=[16, 16],
)
model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

batch_norm_stats_before = clone_helper(model)

env = model.get_env()
observation = env.reset()
first_prediction, _ = model.predict(observation, deterministic=True)
for _ in range(5):
prediction, _ = model.predict(observation, deterministic=True)
np.testing.assert_allclose(first_prediction, prediction)

batch_norm_stats_after = clone_helper(model)

# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()
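
Taken together, the new tests check three things: the batch-norm statistics of the trained networks change during learn() while those of the target networks do not, rollout collection alone (gradient_steps=0) leaves the statistics untouched, and repeated deterministic predict() calls return identical actions without updating the statistics. Assuming a standard development install of the repository, they run under pytest as tests/test_train_eval_mode.py.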
