
Behavioral Cloning Pytorch #4293


Merged · 17 commits · Aug 13, 2020
10 changes: 9 additions & 1 deletion ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -4,7 +4,7 @@
from mlagents_envs.base_env import DecisionSteps

from mlagents.trainers.buffer import AgentBuffer
-from mlagents.trainers.components.bc.module import BCModule
+from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

from mlagents.trainers.policy.torch_policy import TorchPolicy
@@ -27,6 +27,14 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
self.global_step = torch.tensor(0)
self.bc_module: Optional[BCModule] = None
self.create_reward_signals(trainer_settings.reward_signals)
if trainer_settings.behavioral_cloning is not None:
self.bc_module = BCModule(
self.policy,
trainer_settings.behavioral_cloning,
policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
default_batch_size=trainer_settings.hyperparameters.batch_size,
default_num_epoch=3,
)

def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
pass
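For context, a minimal sketch of how this new branch is expected to be exercised, assuming the attrs-based TrainerSettings and BehavioralCloningSettings used in the tests below; the demo path and the strength value are purely illustrative.

from mlagents.trainers.settings import TrainerSettings, BehavioralCloningSettings

# Illustrative only: a settings object that would make TorchOptimizer build a BCModule.
trainer_settings = TrainerSettings()
trainer_settings.behavioral_cloning = BehavioralCloningSettings(
    demo_path="demos/3DBall.demo",  # hypothetical demonstration file
    strength=0.5,                   # scales the policy learning rate for BC (default appears to be 1.0)
    steps=10000,                    # anneal the BC learning rate over this many steps
)
# With behavioral_cloning set, the constructor above creates a BCModule using the
# policy's learning_rate and batch_size as defaults and default_num_epoch=3.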
Binary file not shown.
150 changes: 150 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py
@@ -0,0 +1,150 @@
from unittest.mock import MagicMock
import pytest
import mlagents.trainers.tests.mock_brain as mb

import numpy as np
import os

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.settings import (
TrainerSettings,
BehavioralCloningSettings,
NetworkSettings,
)


def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample):
# model_path = env.external_brain_names[0]
trainer_config = TrainerSettings()
trainer_config.network_settings.memory = (
NetworkSettings.MemorySettings() if use_rnn else None
)
policy = TorchPolicy(
0,
mock_behavior_specs,
trainer_config,
"test",
False,
tanhresample,
tanhresample,
)
bc_module = BCModule(
policy,
settings=bc_settings,
policy_learning_rate=trainer_config.hyperparameters.learning_rate,
default_batch_size=trainer_config.hyperparameters.batch_size,
default_num_epoch=3,
)
return bc_module


# Test default values
def test_bcmodule_defaults():
# See if default values match
mock_specs = mb.create_mock_3dball_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"
)
bc_module = create_bc_module(mock_specs, bc_settings, False, False)
assert bc_module.num_epoch == 3
assert bc_module.batch_size == TrainerSettings().hyperparameters.batch_size
# Assign strange values and see if it overrides properly
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
num_epoch=100,
batch_size=10000,
)
bc_module = create_bc_module(mock_specs, bc_settings, False, False)
assert bc_module.num_epoch == 100
assert bc_module.batch_size == 10000


# Test with continuous control env and vector actions
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_update(is_sac):
mock_specs = mb.create_mock_3dball_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"
)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)


# Test with constant pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_constant_lr_update(is_sac):
mock_specs = mb.create_mock_3dball_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
steps=0,
)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
old_learning_rate = bc_module.current_lr

_ = bc_module.update()
assert old_learning_rate == bc_module.current_lr


# Test with linearly decaying pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_linear_lr_update(is_sac):
mock_specs = mb.create_mock_3dball_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
steps=100,
)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
# Should decay by 10/100 * 0.0003 = 0.00003
bc_module.policy.get_current_step = MagicMock(return_value=10)
old_learning_rate = bc_module.current_lr
_ = bc_module.update()
assert old_learning_rate - 0.00003 == pytest.approx(bc_module.current_lr, abs=0.01)


# Test with RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_update(is_sac):
mock_specs = mb.create_mock_3dball_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"
)
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)


# Test with discrete control and visual observations
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_dc_visual_update(is_sac):
mock_specs = mb.create_mock_banana_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo"
)
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)


# Test with discrete control, visual observations and RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_dc_update(is_sac):
mock_specs = mb.create_mock_banana_behavior_specs()
bc_settings = BehavioralCloningSettings(
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo"
)
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
stats = bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)


if __name__ == "__main__":
pytest.main()
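As a side note on test_bcmodule_linear_lr_update above, a small sketch of the expected arithmetic, assuming the linear schedule in ModelUtils.DecayedValue interpolates from the initial BC learning rate to roughly zero over the configured number of annealing steps:

initial_lr = 0.0003      # TrainerSettings default learning rate, with BC strength 1.0
anneal_steps = 100       # BehavioralCloningSettings(steps=100)
current_step = 10        # mocked via policy.get_current_step

expected_lr = initial_lr * (1 - current_step / anneal_steps)  # ~0.00027
decay = initial_lr - expected_lr                              # ~0.00003, matching the assert above
print(expected_lr, decay)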
Binary file not shown.
Empty file.
185 changes: 185 additions & 0 deletions ml-agents/mlagents/trainers/torch/components/bc/module.py
@@ -0,0 +1,185 @@
from typing import Dict
import numpy as np
import torch

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType
from mlagents.trainers.torch.utils import ModelUtils


class BCModule:
def __init__(
self,
policy: TorchPolicy,
settings: BehavioralCloningSettings,
policy_learning_rate: float,
default_batch_size: int,
default_num_epoch: int,
):
"""
A BC trainer that can be used inline with RL.
:param policy: The policy of the learning model
:param settings: The settings for BehavioralCloning including LR strength, batch_size,
num_epochs, samples_per_update and LR annealing steps.
:param policy_learning_rate: The initial Learning Rate of the policy. Used to set an appropriate learning rate
for the pretrainer.
"""
self.policy = policy
self._anneal_steps = settings.steps
self.current_lr = policy_learning_rate * settings.strength

learning_rate_schedule: ScheduleType = ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
self.decay_learning_rate = ModelUtils.DecayedValue(
learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
)
params = self.policy.actor_critic.parameters()
self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
_, self.demonstration_buffer = demo_to_buffer(
settings.demo_path, policy.sequence_length, policy.behavior_spec
)

self.batch_size = (
settings.batch_size if settings.batch_size else default_batch_size
)
self.num_epoch = settings.num_epoch if settings.num_epoch else default_num_epoch
self.n_sequences = max(
min(self.batch_size, self.demonstration_buffer.num_experiences)
// policy.sequence_length,
1,
)

self.has_updated = False
self.use_recurrent = self.policy.use_recurrent
self.samples_per_update = settings.samples_per_update

def update(self) -> Dict[str, np.ndarray]:
"""
Updates the policy using mini batches drawn from the demonstration buffer.
:return: The mean BC loss of the update.
"""
# Don't continue training if the learning rate has reached 0, to reduce training time.

decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
if self.current_lr <= 0:
return {"Losses/Pretraining Loss": 0}

batch_losses = []
possible_demo_batches = (
self.demonstration_buffer.num_experiences // self.n_sequences
)
possible_batches = possible_demo_batches

max_batches = self.samples_per_update // self.n_sequences

n_epoch = self.num_epoch
for _ in range(n_epoch):
self.demonstration_buffer.shuffle(
sequence_length=self.policy.sequence_length
)
if max_batches == 0:
num_batches = possible_batches
else:
num_batches = min(possible_batches, max_batches)
for i in range(num_batches // self.policy.sequence_length):
demo_update_buffer = self.demonstration_buffer
start = i * self.n_sequences * self.policy.sequence_length
end = (i + 1) * self.n_sequences * self.policy.sequence_length
mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
run_out = self._update_batch(mini_batch_demo, self.n_sequences)
loss = run_out["loss"]
batch_losses.append(loss)

ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.current_lr = decay_lr

self.has_updated = True
update_stats = {"Losses/Pretraining Loss": np.mean(batch_losses)}
return update_stats

def _behavioral_cloning_loss(self, selected_actions, log_probs, expert_actions):
if self.policy.use_continuous_act:
bc_loss = torch.nn.functional.mse_loss(selected_actions, expert_actions)
else:
log_prob_branches = ModelUtils.break_into_branches(
log_probs, self.policy.act_size
)
bc_loss = torch.mean(
torch.stack(
[
torch.sum(
-torch.nn.functional.log_softmax(log_prob_branch, dim=1)
* expert_actions_branch,
dim=1,
)
for log_prob_branch, expert_actions_branch in zip(
log_prob_branches, expert_actions
)
]
)
)
return bc_loss

def _update_batch(
self, mini_batch_demo: Dict[str, np.ndarray], n_sequences: int
) -> Dict[str, float]:
"""
Helper for update(): runs a single gradient step on one demonstration mini batch.
"""
vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
act_masks = None
if self.policy.use_continuous_act:
expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"])
else:
raw_expert_actions = ModelUtils.list_to_tensor(
mini_batch_demo["actions"], dtype=torch.long
)
expert_actions = ModelUtils.actions_to_onehot(
raw_expert_actions, self.policy.act_size
)
act_masks = ModelUtils.list_to_tensor(
np.ones(
(
self.n_sequences * self.policy.sequence_length,
sum(self.policy.behavior_spec.discrete_action_branches),
),
dtype=np.float32,
)
)

memories = []
if self.policy.use_recurrent:
memories = torch.zeros(
1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2
)

if self.policy.use_vis_obs:
vis_obs = []
for idx, _ in enumerate(
self.policy.actor_critic.network_body.visual_encoders
):
vis_ob = ModelUtils.list_to_tensor(
mini_batch_demo["visual_obs%d" % idx]
)
vis_obs.append(vis_ob)
else:
vis_obs = []

selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,
memories=memories,
seq_len=self.policy.sequence_length,
all_log_probs=True,
)
bc_loss = self._behavioral_cloning_loss(
selected_actions, all_log_probs, expert_actions
)
self.optimizer.zero_grad()
bc_loss.backward()

self.optimizer.step()
run_out = {"loss": bc_loss.detach().cpu().numpy()}
return run_out