import numpy as np
+import torch
+from policy.networks import ActorCritic


class BlackJackAgent:
@@ -122,3 +124,75 @@ def update(self, state, action, reward, state_):

    def decrease_eps(self):
        self.epsilon = max(0.01, self.epsilon - 1e-5)
+
+
+class PolicyGradientAgent:
+    def __init__(self, input_dim, action_dim, hidden_dim, gamma, lr):
+        self.gamma = gamma
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.policy = ActorCritic(*input_dim, action_dim, hidden_dim).to(self.device)
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr)
+        self.reward_history, self.action_logprob_history = [], []
+
+    def choose_action(self, state):
+        state = torch.from_numpy(state).float().to(self.device)
+        # ActorCritic returns (value, logits); REINFORCE only needs the policy logits
+        _, action_logits = self.policy(state)
+        action_proba = torch.softmax(action_logits, dim=-1)
+        action_dist = torch.distributions.Categorical(action_proba)
+        action = action_dist.sample()
+        if self.policy.training:
+            log_probas = action_dist.log_prob(action)
+            self.action_logprob_history.append(log_probas)
+        return action.item()
+
+    def store_reward(self, reward):
+        self.reward_history.append(reward)
+
+    def update(self):
+        # discounted Monte Carlo returns: returns[t] = sum_{k >= t} gamma^k * r_k
+        T = len(self.reward_history)
+        discounts = torch.logspace(0, T, steps=T + 1, base=self.gamma, device=self.device)[:T]
+        rewards = torch.tensor(self.reward_history, dtype=torch.float, device=self.device)
+        returns = torch.stack([rewards[t:] @ discounts[t:] for t in range(T)])
+
+        # REINFORCE loss: -G_t * log pi(a_t | s_t), summed over the episode
+        loss = 0
+        for g, log_prob in zip(returns, self.action_logprob_history):
+            loss += -g * log_prob
+
+        # gradient step + reset episode history
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+        self.reward_history, self.action_logprob_history = [], []
+
+
+class ActorCriticAgent:
+    def __init__(self, input_dim, action_dim, hidden_dim, gamma, lr):
+        self.gamma = gamma
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.actor_critic = ActorCritic(*input_dim, action_dim, hidden_dim).to(self.device)
+        self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr)
+        self.log_proba, self.value = None, None
+
+    def choose_action(self, state):
+        state = torch.from_numpy(state).float().to(self.device)
+        self.value, action_logits = self.actor_critic(state)
+        action_proba = torch.softmax(action_logits, dim=-1)
+        action_dist = torch.distributions.Categorical(action_proba)
+        action = action_dist.sample()
+        self.log_proba = action_dist.log_prob(action)
+        return action.item()
+
+    def update(self, reward, state_, done):
+        # one-step TD error; bootstrap from V(s') only when the episode is not done
+        state_ = torch.from_numpy(state_).float().unsqueeze(0).to(self.device)
+        value_, _ = self.actor_critic(state_)
+        td_error = reward + self.gamma * value_.detach() * (1 - int(done)) - self.value
+        critic_loss = td_error.pow(2)
+
+        # actor loss: policy gradient weighted by the TD error (advantage estimate)
+        actor_loss = -td_error.detach() * self.log_proba
+
+        # joint gradient step on actor and critic
+        loss = critic_loss + actor_loss
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
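
For context, PolicyGradientAgent is driven once per episode: choose_action and store_reward collect the trajectory, and update() runs a single REINFORCE step when the episode ends. A minimal driver loop might look like the sketch below; it assumes a pre-0.26 Gym API (env.step returns obs, reward, done, info), a CartPole-v1 environment, and illustrative hyperparameters and module path, none of which are part of this commit.

# Illustrative sketch only: environment, hyperparameters, and import path are assumptions.
import gym

from policy.agents import PolicyGradientAgent  # hypothetical module path

env = gym.make('CartPole-v1')
agent = PolicyGradientAgent(input_dim=env.observation_space.shape,
                            action_dim=env.action_space.n,
                            hidden_dim=128, gamma=0.99, lr=1e-3)

for episode in range(500):
    state, done = env.reset(), False
    while not done:
        action = agent.choose_action(state)
        state, reward, done, _ = env.step(action)
        agent.store_reward(reward)
    agent.update()  # one REINFORCE update per finished episode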
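
ActorCriticAgent, by contrast, performs a one-step TD update after every transition rather than once per episode. Under the same assumed Gym/CartPole setup as above, a sketch of its loop:

# Same assumptions as the previous sketch; only the update schedule differs.
import gym

from policy.agents import ActorCriticAgent  # hypothetical module path

env = gym.make('CartPole-v1')
agent = ActorCriticAgent(input_dim=env.observation_space.shape,
                         action_dim=env.action_space.n,
                         hidden_dim=128, gamma=0.99, lr=1e-3)

for episode in range(500):
    state, done = env.reset(), False
    while not done:
        action = agent.choose_action(state)
        state_, reward, done, _ = env.step(action)
        agent.update(reward, state_, done)  # one-step TD update per transition
        state = state_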