Commit 83e17bb

Behavioral Cloning Pytorch (#4293)
1 parent 9874a35 commit 83e17bb

File tree: 6 files changed, +344 / -1 lines changed

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 9 additions & 1 deletion
@@ -4,7 +4,7 @@
 from mlagents_envs.base_env import DecisionSteps

 from mlagents.trainers.buffer import AgentBuffer
-from mlagents.trainers.components.bc.module import BCModule
+from mlagents.trainers.torch.components.bc.module import BCModule
 from mlagents.trainers.torch.components.reward_providers import create_reward_provider

 from mlagents.trainers.policy.torch_policy import TorchPolicy
@@ -27,6 +27,14 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
         self.global_step = torch.tensor(0)
         self.bc_module: Optional[BCModule] = None
         self.create_reward_signals(trainer_settings.reward_signals)
+        if trainer_settings.behavioral_cloning is not None:
+            self.bc_module = BCModule(
+                self.policy,
+                trainer_settings.behavioral_cloning,
+                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
+                default_batch_size=trainer_settings.hyperparameters.batch_size,
+                default_num_epoch=3,
+            )

     def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         pass
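For context, the BCModule is only constructed when the trainer settings carry a behavioral_cloning section. Below is a minimal sketch of how such settings might be assembled in code, using only names that appear in this diff (TrainerSettings, BehavioralCloningSettings, the behavioral_cloning attribute, and the strength/steps fields read by BCModule); the demo path and numeric values are purely illustrative, not taken from the commit.

```python
from mlagents.trainers.settings import TrainerSettings, BehavioralCloningSettings

# Illustrative values only; the demo path below is hypothetical.
settings = TrainerSettings()
settings.behavioral_cloning = BehavioralCloningSettings(
    demo_path="demos/Expert3DBall.demo",
    strength=0.5,   # scales the policy learning rate for the BC updates
    steps=10000,    # anneal the BC learning rate over this many steps
)
# With behavioral_cloning set, TorchOptimizer.__init__ (above) builds a BCModule
# from the policy's learning rate and batch size, with default_num_epoch=3.
```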
Binary file not shown.
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
from unittest.mock import MagicMock
import pytest
import mlagents.trainers.tests.mock_brain as mb

import numpy as np
import os

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.settings import (
    TrainerSettings,
    BehavioralCloningSettings,
    NetworkSettings,
)


def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample):
    # model_path = env.external_brain_names[0]
    trainer_config = TrainerSettings()
    trainer_config.network_settings.memory = (
        NetworkSettings.MemorySettings() if use_rnn else None
    )
    policy = TorchPolicy(
        0,
        mock_behavior_specs,
        trainer_config,
        "test",
        False,
        tanhresample,
        tanhresample,
    )
    bc_module = BCModule(
        policy,
        settings=bc_settings,
        policy_learning_rate=trainer_config.hyperparameters.learning_rate,
        default_batch_size=trainer_config.hyperparameters.batch_size,
        default_num_epoch=3,
    )
    return bc_module


# Test default values
def test_bcmodule_defaults():
    # See if default values match
    mock_specs = mb.create_mock_3dball_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, False)
    assert bc_module.num_epoch == 3
    assert bc_module.batch_size == TrainerSettings().hyperparameters.batch_size
    # Assign strange values and see if it overrides properly
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
        num_epoch=100,
        batch_size=10000,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, False)
    assert bc_module.num_epoch == 100
    assert bc_module.batch_size == 10000


# Test with continuous control env and vector actions
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_update(is_sac):
    mock_specs = mb.create_mock_3dball_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    stats = bc_module.update()
    for _, item in stats.items():
        assert isinstance(item, np.float32)


# Test with constant pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_constant_lr_update(is_sac):
    mock_specs = mb.create_mock_3dball_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
        steps=0,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    stats = bc_module.update()
    for _, item in stats.items():
        assert isinstance(item, np.float32)
    old_learning_rate = bc_module.current_lr

    _ = bc_module.update()
    assert old_learning_rate == bc_module.current_lr

# Test with linearly-annealed pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_linear_lr_update(is_sac):
    mock_specs = mb.create_mock_3dball_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
        steps=100,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    # Should decay by 10/100 * 0.0003 = 0.00003
    bc_module.policy.get_current_step = MagicMock(return_value=10)
    old_learning_rate = bc_module.current_lr
    _ = bc_module.update()
    assert old_learning_rate - 0.00003 == pytest.approx(bc_module.current_lr, abs=0.01)


# Test with RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_update(is_sac):
    mock_specs = mb.create_mock_3dball_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo"
    )
    bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
    stats = bc_module.update()
    for _, item in stats.items():
        assert isinstance(item, np.float32)


# Test with discrete control and visual observations
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_dc_visual_update(is_sac):
    mock_specs = mb.create_mock_banana_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo"
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    stats = bc_module.update()
    for _, item in stats.items():
        assert isinstance(item, np.float32)


# Test with discrete control, visual observations and RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_dc_update(is_sac):
    mock_specs = mb.create_mock_banana_behavior_specs()
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo"
    )
    bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
    stats = bc_module.update()
    for _, item in stats.items():
        assert isinstance(item, np.float32)


if __name__ == "__main__":
    pytest.main()
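The expected value in test_bcmodule_linear_lr_update follows from its in-test comment: with a linear schedule over 100 anneal steps, evaluated at step 10, the learning rate should drop by 10/100 * 0.0003 = 0.00003. A minimal sketch of that arithmetic, assuming straight-line interpolation from the initial rate toward zero (the 0.0003 starting rate is the default policy learning rate the test relies on):

```python
# Sketch of the linear decay the test asserts (assumed formula:
# straight-line interpolation from the initial LR toward ~0 over `steps`).
initial_lr = 3e-4      # default policy learning rate assumed by the test
anneal_steps = 100     # BehavioralCloningSettings(steps=100)
current_step = 10      # mocked via MagicMock(return_value=10)

decayed = initial_lr * (1.0 - current_step / anneal_steps)
print(decayed)               # 0.00027
print(initial_lr - decayed)  # 0.00003, the delta asserted in the test
```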
Binary file not shown.

ml-agents/mlagents/trainers/torch/components/bc/__init__.py

Whitespace-only changes.
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
from typing import Dict
import numpy as np
import torch

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType
from mlagents.trainers.torch.utils import ModelUtils


class BCModule:
    def __init__(
        self,
        policy: TorchPolicy,
        settings: BehavioralCloningSettings,
        policy_learning_rate: float,
        default_batch_size: int,
        default_num_epoch: int,
    ):
        """
        A BC trainer that can be used inline with RL.
        :param policy: The policy of the learning model
        :param settings: The settings for BehavioralCloning including LR strength, batch_size,
        num_epochs, samples_per_update and LR annealing steps.
        :param policy_learning_rate: The initial Learning Rate of the policy. Used to set an appropriate learning rate
        for the pretrainer.
        """
        self.policy = policy
        self._anneal_steps = settings.steps
        self.current_lr = policy_learning_rate * settings.strength

        learning_rate_schedule: ScheduleType = (
            ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
        )
        self.decay_learning_rate = ModelUtils.DecayedValue(
            learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
        )
        params = self.policy.actor_critic.parameters()
        self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
        _, self.demonstration_buffer = demo_to_buffer(
            settings.demo_path, policy.sequence_length, policy.behavior_spec
        )

        self.batch_size = (
            settings.batch_size if settings.batch_size else default_batch_size
        )
        self.num_epoch = settings.num_epoch if settings.num_epoch else default_num_epoch
        self.n_sequences = max(
            min(self.batch_size, self.demonstration_buffer.num_experiences)
            // policy.sequence_length,
            1,
        )

        self.has_updated = False
        self.use_recurrent = self.policy.use_recurrent
        self.samples_per_update = settings.samples_per_update

    def update(self) -> Dict[str, np.ndarray]:
        """
        Updates model using buffer.
        :param max_batches: The maximum number of batches to use per update.
        :return: The loss of the update.
        """
        # Don't continue training if the learning rate has reached 0, to reduce training time.

        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        if self.current_lr <= 0:
            return {"Losses/Pretraining Loss": 0}

        batch_losses = []
        possible_demo_batches = (
            self.demonstration_buffer.num_experiences // self.n_sequences
        )
        possible_batches = possible_demo_batches

        max_batches = self.samples_per_update // self.n_sequences

        n_epoch = self.num_epoch
        for _ in range(n_epoch):
            self.demonstration_buffer.shuffle(
                sequence_length=self.policy.sequence_length
            )
            if max_batches == 0:
                num_batches = possible_batches
            else:
                num_batches = min(possible_batches, max_batches)
            for i in range(num_batches // self.policy.sequence_length):
                demo_update_buffer = self.demonstration_buffer
                start = i * self.n_sequences * self.policy.sequence_length
                end = (i + 1) * self.n_sequences * self.policy.sequence_length
                mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
                run_out = self._update_batch(mini_batch_demo, self.n_sequences)
                loss = run_out["loss"]
                batch_losses.append(loss)

        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.current_lr = decay_lr

        self.has_updated = True
        update_stats = {"Losses/Pretraining Loss": np.mean(batch_losses)}
        return update_stats

    def _behavioral_cloning_loss(self, selected_actions, log_probs, expert_actions):
        if self.policy.use_continuous_act:
            bc_loss = torch.nn.functional.mse_loss(selected_actions, expert_actions)
        else:
            log_prob_branches = ModelUtils.break_into_branches(
                log_probs, self.policy.act_size
            )
            bc_loss = torch.mean(
                torch.stack(
                    [
                        torch.sum(
                            -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                            * expert_actions_branch,
                            dim=1,
                        )
                        for log_prob_branch, expert_actions_branch in zip(
                            log_prob_branches, expert_actions
                        )
                    ]
                )
            )
        return bc_loss

    def _update_batch(
        self, mini_batch_demo: Dict[str, np.ndarray], n_sequences: int
    ) -> Dict[str, float]:
        """
        Helper function for update_batch.
        """
        vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
        act_masks = None
        if self.policy.use_continuous_act:
            expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"])
        else:
            raw_expert_actions = ModelUtils.list_to_tensor(
                mini_batch_demo["actions"], dtype=torch.long
            )
            expert_actions = ModelUtils.actions_to_onehot(
                raw_expert_actions, self.policy.act_size
            )
            act_masks = ModelUtils.list_to_tensor(
                np.ones(
                    (
                        self.n_sequences * self.policy.sequence_length,
                        sum(self.policy.behavior_spec.discrete_action_branches),
                    ),
                    dtype=np.float32,
                )
            )

        memories = []
        if self.policy.use_recurrent:
            memories = torch.zeros(
                1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2
            )

        if self.policy.use_vis_obs:
            vis_obs = []
            for idx, _ in enumerate(
                self.policy.actor_critic.network_body.visual_encoders
            ):
                vis_ob = ModelUtils.list_to_tensor(
                    mini_batch_demo["visual_obs%d" % idx]
                )
                vis_obs.append(vis_ob)
        else:
            vis_obs = []

        selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
            all_log_probs=True,
        )
        bc_loss = self._behavioral_cloning_loss(
            selected_actions, all_log_probs, expert_actions
        )
        self.optimizer.zero_grad()
        bc_loss.backward()

        self.optimizer.step()
        run_out = {"loss": bc_loss.detach().cpu().numpy()}
        return run_out
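To make the discrete-action branch of _behavioral_cloning_loss concrete: it computes a per-branch cross-entropy between the policy's log-probabilities and the one-hot expert actions, then averages across branches. Below is a self-contained toy version of the same computation in plain PyTorch, with no ML-Agents dependencies; the branch sizes and values are made up for illustration.

```python
import torch

# Two hypothetical discrete branches with 3 and 2 actions, batch of 4.
logits_branch_0 = torch.randn(4, 3)
logits_branch_1 = torch.randn(4, 2)
expert_onehot_0 = torch.nn.functional.one_hot(torch.tensor([0, 2, 1, 0]), 3).float()
expert_onehot_1 = torch.nn.functional.one_hot(torch.tensor([1, 0, 0, 1]), 2).float()

# Per-branch cross-entropy against the one-hot expert actions.
per_branch = [
    torch.sum(-torch.nn.functional.log_softmax(logits, dim=1) * expert, dim=1)
    for logits, expert in zip(
        [logits_branch_0, logits_branch_1], [expert_onehot_0, expert_onehot_1]
    )
]
# Average over branches (and the batch dimension) to get a scalar BC loss.
bc_loss = torch.mean(torch.stack(per_branch))
print(bc_loss)
```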
