# ppo.py (forked from devendrachaplot/Neural-SLAM)

# The following code is largely borrowed from:
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/algo/ppo.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class PPO():

    def __init__(
            self,
            actor_critic,
            clip_param,
            ppo_epoch,
            num_mini_batch,
            value_loss_coef,
            entropy_coef,
            lr=None,
            eps=None,
            max_grad_norm=None,
            use_clipped_value_loss=True):

        self.actor_critic = actor_critic

        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.num_mini_batch = num_mini_batch

        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef

        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

        # Optimize only the parameters that require gradients (parts of the
        # actor-critic may be frozen).
        self.optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, actor_critic.parameters()),
            lr=lr, eps=eps)

    def update(self, rollouts):
        # Advantage estimates for the collected rollout, normalized to zero
        # mean and unit variance for a more stable policy-gradient step.
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        advantages = (advantages - advantages.mean()) / (
            advantages.std() + 1e-5)

        value_loss_epoch = 0
        action_loss_epoch = 0
        dist_entropy_epoch = 0

        for e in range(self.ppo_epoch):
            if self.actor_critic.is_recurrent:
                data_generator = rollouts.recurrent_generator(
                    advantages, self.num_mini_batch)
            else:
                data_generator = rollouts.feed_forward_generator(
                    advantages, self.num_mini_batch)

            for sample in data_generator:
                value_preds = sample['value_preds']
                returns = sample['returns']
                adv_targ = sample['adv_targ']

                # Reshape to do in a single forward pass for all steps
                values, action_log_probs, dist_entropy, _ = \
                    self.actor_critic.evaluate_actions(
                        sample['obs'], sample['rec_states'],
                        sample['masks'], sample['actions'],
                        extras=sample['extras']
                    )

                # Clipped surrogate objective (policy loss).
                ratio = torch.exp(action_log_probs -
                                  sample['old_action_log_probs'])
                surr1 = ratio * adv_targ
                surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                    1.0 + self.clip_param) * adv_targ
                action_loss = -torch.min(surr1, surr2).mean()

                if self.use_clipped_value_loss:
                    # Clip the new value prediction around the old one and
                    # take the larger (pessimistic) of the two squared errors.
                    value_pred_clipped = value_preds + \
                        (values - value_preds).clamp(
                            -self.clip_param, self.clip_param)
                    value_losses = (values - returns).pow(2)
                    value_losses_clipped = (value_pred_clipped
                                            - returns).pow(2)
                    value_loss = .5 * torch.max(value_losses,
                                                value_losses_clipped).mean()
                else:
                    value_loss = 0.5 * (returns - values).pow(2).mean()

                self.optimizer.zero_grad()
                (value_loss * self.value_loss_coef + action_loss -
                 dist_entropy * self.entropy_coef).backward()
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                                         self.max_grad_norm)
                self.optimizer.step()

                value_loss_epoch += value_loss.item()
                action_loss_epoch += action_loss.item()
                dist_entropy_epoch += dist_entropy.item()

        num_updates = self.ppo_epoch * self.num_mini_batch

        value_loss_epoch /= num_updates
        action_loss_epoch /= num_updates
        dist_entropy_epoch /= num_updates

        return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
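

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). The toy
# actor-critic and rollout storage below are hypothetical stand-ins that
# expose just enough of the interface PPO.update() expects; the real model
# and rollout-storage classes are defined elsewhere in the repository.
if __name__ == "__main__":

    class _ToyActorCritic(nn.Module):
        """Tiny categorical policy with a value head (hypothetical)."""

        is_recurrent = False

        def __init__(self, obs_dim=4, num_actions=3):
            super().__init__()
            self.policy = nn.Linear(obs_dim, num_actions)
            self.value = nn.Linear(obs_dim, 1)

        def evaluate_actions(self, obs, rec_states, masks, actions,
                             extras=None):
            dist = torch.distributions.Categorical(logits=self.policy(obs))
            action_log_probs = dist.log_prob(
                actions.squeeze(-1)).unsqueeze(-1)
            return (self.value(obs), action_log_probs,
                    dist.entropy().mean(), rec_states)

    class _ToyRollouts:
        """One random rollout for a single process (hypothetical)."""

        def __init__(self, num_steps=8, obs_dim=4, num_actions=3):
            self.returns = torch.randn(num_steps + 1, 1)
            self.value_preds = torch.randn(num_steps + 1, 1)
            self.obs = torch.randn(num_steps, obs_dim)
            self.actions = torch.randint(0, num_actions, (num_steps, 1))
            self.old_action_log_probs = torch.zeros(num_steps, 1)
            self.masks = torch.ones(num_steps, 1)
            self.rec_states = torch.zeros(num_steps, 1)
            self.extras = torch.zeros(num_steps, 1)

        def feed_forward_generator(self, advantages, num_mini_batch):
            # A single mini-batch covering the whole rollout, for simplicity.
            yield {
                'obs': self.obs,
                'rec_states': self.rec_states,
                'masks': self.masks,
                'actions': self.actions,
                'value_preds': self.value_preds[:-1],
                'returns': self.returns[:-1],
                'old_action_log_probs': self.old_action_log_probs,
                'adv_targ': advantages,
                'extras': self.extras,
            }

    agent = PPO(_ToyActorCritic(), clip_param=0.2, ppo_epoch=4,
                num_mini_batch=1, value_loss_coef=0.5, entropy_coef=0.01,
                lr=2.5e-4, eps=1e-5, max_grad_norm=0.5)
    print(agent.update(_ToyRollouts()))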