import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from config import gamma, truncation_clip, delta, max_gradient_norm, trust_region_decay
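# NOTE: `config` is not shown in this diff; it is assumed to define the ACER
# hyperparameters imported above. Illustrative values only (not from this repo):
# gamma ~ 0.99 (discount), truncation_clip ~ 10 (the constant c used to truncate
# importance weights), delta ~ 1 (trust-region constraint), max_gradient_norm ~ 40,
# trust_region_decay ~ 0.99 (Polyak factor for the average policy network).
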
class Model(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Model, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        # Shared trunk with separate actor (policy) and critic (Q-value) heads.
        self.fc = nn.Linear(num_inputs, 128)
        self.fc_actor = nn.Linear(128, num_outputs)
        self.fc_critic = nn.Linear(128, num_outputs)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, inputs):
        x = F.relu(self.fc(inputs))
        policy = F.softmax(self.fc_actor(x), dim=1)
        q_value = self.fc_critic(x)
        # V(s) = sum_a pi(a|s) * Q(s, a), the policy-weighted state value.
        value = (policy * q_value).sum(-1)
        return policy, q_value, value

class LocalModel(Model):
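    """Per-worker copy of the model (A3C-style): gradients are computed locally
    and applied to the shared global model, while a slowly moving average of the
    global parameters is maintained for the ACER trust region.
    """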
    def __init__(self, num_inputs, num_outputs):
        super(LocalModel, self).__init__(num_inputs, num_outputs)

    def pull_from_global_model(self, global_model):
        # Synchronize the worker's parameters with the shared global model.
        self.load_state_dict(global_model.state_dict())

    def update_model(self, loss, global_optimizer, global_model, global_average_model):
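        """Apply this worker's gradients to the shared global model and refresh
        the exponentially averaged model used for the trust region.
        """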
        global_optimizer.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_norm_(self.parameters(), max_gradient_norm)

        # Copy the locally computed gradients onto the shared global parameters,
        # then take an optimization step on the global model.
        for lp, gp in zip(self.parameters(), global_model.parameters()):
            gp.grad = lp.grad

        global_optimizer.step()

        # Soft-update the average policy network in place:
        # avg <- decay * avg + (1 - decay) * global
        with torch.no_grad():
            for gp, gap in zip(global_model.parameters(), global_average_model.parameters()):
                gap.copy_(trust_region_decay * gap + (1 - trust_region_decay) * gp)

    def compute_q_retraces(self, rewards, masks, values, q_actions, rho_actions, next_value):
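        """Compute Retrace targets (as in ACER) for every transition, working
        backwards from the bootstrap value ``next_value`` of the state that
        follows the last step:

            Q_ret(t) = r_t + gamma * mask_t * Q_ret(t+1)

        Before moving to the previous step, the running target is corrected
        with the truncated importance weight of the sampled action:

            Q_ret <- min(1, rho_t) * (Q_ret - Q(s_t, a_t)) + V(s_t)
        """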
        q_retraces = torch.zeros(rewards.size())
        q_ret = next_value

        for step in reversed(range(len(rewards))):
            q_ret = rewards[step] + gamma * masks[step] * q_ret
            q_retraces[step] = q_ret
            # Retrace correction with the truncated importance weight min(1, rho_t).
            q_ret = rho_actions[step].clamp(max=1) * (q_ret - q_actions[step]) + values[step]

        return q_retraces

    def get_loss(self, on_policy, trajectory, average_model):
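        """ACER loss for one trajectory: a policy-gradient term with truncated
        importance weights plus a bias-correction term over all actions, a mean
        squared error between the (detached) Retrace targets and Q(s, a) for the
        critic, and a trust-region penalty that keeps the current policy close
        to the average policy network via KL(average || current).
        """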
        states, next_states, actions, rewards, masks, old_policies = trajectory
        states = torch.stack(states)
        next_states = torch.stack(next_states)
        actions = torch.Tensor(actions).long().view(-1, 1)
        rewards = torch.Tensor(rewards)
        masks = torch.Tensor(masks)
        old_policies = torch.stack(old_policies)

        states = states.view(-1, self.num_inputs)
        next_states = next_states.view(-1, self.num_inputs)
        policies, Qs, Vs = self.forward(states)

        Q_actions = Qs.gather(1, actions).view(-1)

        # Importance weights between the current policy and the behaviour policy
        # that generated the trajectory (all ones when learning on-policy).
        if not on_policy:
            rhos = policies / old_policies
        else:
            rhos = torch.ones_like(policies)

        rho_actions = rhos.gather(1, actions).view(-1)

        # Bootstrap from the value of the state following the last transition,
        # or from zero if that transition ended the episode (mask == 0).
        with torch.no_grad():
            _, _, next_Vs = self.forward(next_states)
        Qret = next_Vs[-1] * masks[-1]
        Qrets = self.compute_q_retraces(rewards, masks, Vs, Q_actions, rho_actions, Qret)
        log_policy = torch.log(policies)
        log_policy_action = log_policy.gather(1, actions).view(-1)

        # Truncated importance-sampling policy gradient ...
        actor_loss_1 = -(log_policy_action * (
            rho_actions.clamp(max=truncation_clip) * (Qrets - Vs)
        ).detach()).mean()
        # ... plus the bias-correction term over all actions.
        actor_loss_2 = -(log_policy * (
            (1 - truncation_clip / rhos).clamp(min=0) * policies * (Qs - Vs.view(-1, 1).expand_as(Qs))
        ).detach()).sum(1).mean()
        actor_loss = actor_loss_1 + actor_loss_2

        # Critic regression towards the (detached) Retrace targets.
        value_loss = ((Qrets.detach() - Q_actions) ** 2).mean()

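        # Trust-region step (ACER): g approximates the gradient of the actor
        # objective with respect to the policy probabilities, k the gradient of
        # KL(average_policy || policy); the correction keeps k . g below delta.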
        g_1 = (1 / policies.gather(1, actions).view(-1)) * (
            rho_actions.clamp(max=truncation_clip) * (Qrets - Vs)
        )
        g_2 = ((1 / policies) * (
            (1 - truncation_clip / rhos).clamp(min=0) * policies * (Qs - Vs.view(-1, 1).expand_as(Qs))
        )).sum(1)
        g = (g_1 + g_2).detach()

        # The average policy network is a slowly moving copy of the global model
        # and is treated as a constant here.
        with torch.no_grad():
            average_policies, _, _ = average_model(states)
        # Gradient of KL(average || current) with respect to the probability of
        # the sampled action: -avg(a|s) / pi(a|s).
        k = -(average_policies / policies).gather(1, actions).view(-1)

        kl = (average_policies * torch.log(average_policies / policies)).sum(1).mean(0)

        k_dot_g = (k * g).sum()
        k_dot_k = (k * k).sum()

        # Scale of the correction: max(0, (k . g - delta) / ||k||^2).
        adj = ((k_dot_g - delta) / k_dot_k).clamp(min=0).detach()
        trust_region_actor_loss = actor_loss + adj * kl

        loss = trust_region_actor_loss + value_loss

        return loss

    def get_action(self, inputs):
        # Sample an action from the current policy distribution.
        policy, _, _ = self.forward(inputs)
        policy = policy[0].detach().numpy()

        action = np.random.choice(self.num_outputs, 1, p=policy)[0]
        return action, policy
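
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original code): builds a global model,
# an average model, and a local worker, fabricates a short random trajectory in
# the format expected by get_loss, and runs a single off-policy update. The
# trajectory contents and the Adam learning rate are illustrative only.
if __name__ == "__main__":
    import torch.optim as optim

    num_inputs, num_outputs, steps = 4, 2, 5

    global_model = Model(num_inputs, num_outputs)
    global_average_model = Model(num_inputs, num_outputs)
    global_average_model.load_state_dict(global_model.state_dict())
    local_model = LocalModel(num_inputs, num_outputs)
    global_optimizer = optim.Adam(global_model.parameters(), lr=1e-3)

    local_model.pull_from_global_model(global_model)

    states = [torch.randn(num_inputs) for _ in range(steps)]
    next_states = [torch.randn(num_inputs) for _ in range(steps)]
    actions = [np.random.randint(num_outputs) for _ in range(steps)]
    rewards = [1.0] * steps
    masks = [1.0] * (steps - 1) + [0.0]  # last transition ends the episode
    old_policies = [F.softmax(torch.randn(num_outputs), dim=0) for _ in range(steps)]

    trajectory = (states, next_states, actions, rewards, masks, old_policies)
    loss = local_model.get_loss(on_policy=False, trajectory=trajectory,
                                average_model=global_average_model)
    local_model.update_model(loss, global_optimizer, global_model, global_average_model)
    print("one ACER update done, loss =", loss.item())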