forked from datawhalechina/easy-rl

Commit 66fd8ef (1 parent: 8d06642), showing 7 changed files with 255 additions and 0 deletions.
@@ -0,0 +1,35 @@
## Approach

See [my blog post](https://blog.csdn.net/JohnJim0/article/details/109557173).

## Requirements

python 3.7.9

pytorch 1.6.0

tensorboard 2.3.0

torchvision 0.7.0

## Usage

Train:

```bash
python main.py
```

Evaluate:

```bash
python main.py --train 0
```

Visualize:

```bash
tensorboard --logdir logs
```

## PyTorch notes

[with torch.no_grad()](https://www.jianshu.com/p/1cea017f5d11)
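For readers who skip the linked article, here is a minimal standalone sketch (not part of this commit) of what `with torch.no_grad()` does: it disables autograd tracking inside its block, which is why the forward passes in `choose_action` below are wrapped in it.

```python
import torch

x = torch.ones(2, 2, requires_grad=True)

# Operations run inside no_grad() are not tracked by autograd,
# so the result has no grad_fn and requires_grad is False.
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False

# The same operation outside no_grad() is tracked as usual.
z = x * 2
print(z.requires_grad)  # True
```

No gradients are needed for action selection or evaluation, so skipping the autograd bookkeeping saves memory and time.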
@@ -0,0 +1,127 @@
#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2021-03-13 14:56:23
@Discription:
@Environment: python 3.7.7
'''
'''off-policy
'''


import torch
import torch.nn as nn
import torch.optim as optim
import random
import math
import numpy as np
from common.memory import ReplayBuffer
from common.model import MLP2
class DQN:
    def __init__(self, n_states, n_actions, cfg):

        self.n_actions = n_actions  # total number of actions
        self.device = cfg.device  # device, e.g. cpu or gpu
        self.gamma = cfg.gamma  # discount factor for rewards
        # parameters of the e-greedy policy
        self.sample_count = 0  # sample counter used for epsilon decay
        self.epsilon = 0
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        self.batch_size = cfg.batch_size
        self.policy_net = MLP2(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP2(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
        # initialize target_net as an exact copy of policy_net's parameters
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # disable BatchNormalization and Dropout
        # see the difference between parameters() and state_dict(): the former has requires_grad=True
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.loss = 0
        self.memory = ReplayBuffer(cfg.memory_capacity)

    def choose_action(self, state, train=True):
        '''Choose an action.
        '''
        if train:
            self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                math.exp(-1. * self.sample_count / self.epsilon_decay)
            self.sample_count += 1
            if random.random() > self.epsilon:
                with torch.no_grad():
                    # convert the state to a tensor before feeding it to the network; its entries are originally float64
                    # note that state=torch.tensor(state).unsqueeze(0) is equivalent to state=torch.tensor([state])
                    state = torch.tensor(
                        [state], device=self.device, dtype=torch.float32)
                    # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
                    q_value = self.policy_net(state)
                    # tensor.max(1) returns the max value of each row together with its index,
                    # e.g. torch.return_types.max(values=tensor([10.3587]), indices=tensor([0])),
                    # so tensor.max(1)[1] returns the index of the max value, i.e. the action
                    action = q_value.max(1)[1].item()
            else:
                action = random.randrange(self.n_actions)
            return action
        else:
            with torch.no_grad():  # do not track gradients
                # convert the state to a tensor before feeding it to the network; its entries are originally float64
                # keep it on self.device so it matches the device of target_net
                state = torch.tensor(
                    [state], device=self.device, dtype=torch.float32)
                q_value = self.target_net(state)
                # tensor.max(1)[1] again returns the index of the row maximum, i.e. the greedy action
                action = q_value.max(1)[1].item()
            return action
    def update(self):

        if len(self.memory) < self.batch_size:
            return
        # randomly sample a batch of transitions from the replay memory
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        '''convert to tensors,
        e.g. tensor([[-4.5543e-02, -2.3910e-01, 1.8344e-02, 2.3158e-01],...,[-1.8615e-02, -2.3921e-01, -1.1791e-02, 2.3400e-01]])'''
        state_batch = torch.tensor(
            state_batch, device=self.device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(
            1)  # e.g. tensor([[1],...,[0]])
        reward_batch = torch.tensor(
            reward_batch, device=self.device, dtype=torch.float)  # e.g. tensor([1., 1.,...,1])
        next_state_batch = torch.tensor(
            next_state_batch, device=self.device, dtype=torch.float)
        # convert the bool done flags to float and then to a tensor of shape [batch_size]
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)

        '''compute Q(s_t, a) for the current (s_t, a) pairs'''
        '''torch.gather: for a=torch.Tensor([[1,2],[3,4]]), a.gather(1, torch.LongTensor([[0],[1]]))=torch.Tensor([[1],[3]])'''
        q_values = self.policy_net(state_batch).gather(
            dim=1, index=action_batch)  # equivalent to calling self.policy_net.forward
        # compute V(s_{t+1}) for all next states, i.e. the max Q-value predicted by target_net
        next_state_values = self.target_net(
            next_state_batch).max(1)[0].detach()  # e.g. tensor([ 0.0060, -0.0171,...,])
        # compute expected_q_values
        # for terminal states done_batch is 1, so expected_q_values reduces to the reward
        expected_q_values = reward_batch + self.gamma * \
            next_state_values * (1 - done_batch)
        # self.loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))  # Huber loss
        self.loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))  # mean squared error loss
        # optimize the model
        self.optimizer.zero_grad()  # clear all old gradients from the last step
        # loss.backward() computes the gradient of the loss w.r.t. all parameters that require gradients
        self.loss.backward()
        for param in self.policy_net.parameters():  # clip gradients to prevent them from exploding
            param.grad.data.clamp_(-1, 1)

        self.optimizer.step()  # update the model

    def save(self, path):
        torch.save(self.target_net.state_dict(), path + 'dqn_checkpoint.pth')

    def load(self, path):
        self.target_net.load_state_dict(torch.load(path + 'dqn_checkpoint.pth'))
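The comments in `choose_action` and `update` lean on two PyTorch indexing idioms, `tensor.max(1)` and `tensor.gather`. As a small standalone illustration (not part of the commit), with a batch of two states and two actions:

```python
import torch

q = torch.tensor([[1.0, 5.0],
                  [4.0, 2.0]])  # hypothetical Q-values, one row per state

# max(1) reduces over dim=1 (the action dimension): values are the per-row
# maxima, indices are the corresponding greedy actions.
values, indices = q.max(1)
print(values)   # tensor([5., 4.])
print(indices)  # tensor([1, 0])

# gather(dim=1, index=...) picks one Q-value per row, here the Q-value of
# the action that was actually taken in each transition.
actions = torch.tensor([[1], [0]])
print(q.gather(dim=1, index=actions))  # tensor([[5.], [4.]])
```

This is the pattern `update` uses: `gather` selects Q(s_t, a_t) from the policy net, and `max(1)[0]` on the target net output gives the bootstrap value for the next state.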
@@ -0,0 +1,93 @@
#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:48:57
@LastEditor: John
LastEditTime: 2021-03-13 14:56:50
@Discription:
@Environment: python 3.7.7
'''
import sys, os
sys.path.append(os.getcwd())  # add the current working directory to the import path
import gym
import torch
import datetime
from DQN.agent import DQN
from common.plot import plot_rewards
from common.utils import save_results

SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # current timestamp
SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0] + "/saved_model/" + SEQUENCE + '/'  # path for saving the model
if not os.path.exists(os.path.split(os.path.abspath(__file__))[0] + "/saved_model/"):  # create the folder if it does not exist
    os.mkdir(os.path.split(os.path.abspath(__file__))[0] + "/saved_model/")
if not os.path.exists(SAVED_MODEL_PATH):  # create the folder if it does not exist
    os.mkdir(SAVED_MODEL_PATH)
RESULT_PATH = os.path.split(os.path.abspath(__file__))[0] + "/results/" + SEQUENCE + '/'  # path for storing the rewards
if not os.path.exists(os.path.split(os.path.abspath(__file__))[0] + "/results/"):  # create the folder if it does not exist
    os.mkdir(os.path.split(os.path.abspath(__file__))[0] + "/results/")
if not os.path.exists(RESULT_PATH):  # create the folder if it does not exist
    os.mkdir(RESULT_PATH)

class DQNConfig:
    def __init__(self):
        self.algo = "DQN"  # algorithm name
        self.gamma = 0.99
        self.epsilon_start = 0.95  # initial epsilon of the e-greedy policy
        self.epsilon_end = 0.01
        self.epsilon_decay = 200
        self.lr = 0.01  # learning rate
        self.memory_capacity = 800  # capacity of the replay memory
        self.batch_size = 64
        self.train_eps = 250  # number of training episodes
        self.train_steps = 200  # maximum number of steps per training episode
        self.target_update = 2  # update frequency of the target net (in episodes)
        self.eval_eps = 20  # number of evaluation episodes
        self.eval_steps = 200  # maximum number of steps per evaluation episode
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # detect gpu
        self.hidden_dim = 128  # dimension of the hidden layer of the neural network

def train(cfg, env, agent):
    print('Start to train !')
    rewards = []
    ma_rewards = []  # moving-average rewards
    ep_steps = []
    for i_episode in range(cfg.train_eps):
        state = env.reset()  # reset the environment
        ep_reward = 0
        for i_step in range(cfg.train_steps):
            action = agent.choose_action(state)  # choose an action given the current state
            next_state, reward, done, _ = env.step(action)  # step the environment
            ep_reward += reward
            agent.memory.push(state, action, reward, next_state, done)  # store the transition in the replay memory
            state = next_state  # move on to the next state
            agent.update()  # update the network at every step
            if done:
                break
        # update the target network by copying all weights and biases from the policy net
        if i_episode % cfg.target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        print('Episode:{}/{}, Reward:{}, Steps:{}, Done:{}'.format(i_episode + 1, cfg.train_eps, ep_reward, i_step, done))
        ep_steps.append(i_step)
        rewards.append(ep_reward)
        # update the moving-average reward
        if ma_rewards:
            ma_rewards.append(
                0.9 * ma_rewards[-1] + 0.1 * ep_reward)
        else:
            ma_rewards.append(ep_reward)
    print('Complete training!')
    return rewards, ma_rewards

if __name__ == "__main__":
    cfg = DQNConfig()
    env = gym.make('CartPole-v0').unwrapped  # you can look up why gym envs are sometimes unwrapped; it is usually not needed here
    env.seed(1)  # set the environment random seed
    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n
    agent = DQN(n_states, n_actions, cfg)
    rewards, ma_rewards = train(cfg, env, agent)
    agent.save(path=SAVED_MODEL_PATH)
    save_results(rewards, ma_rewards, tag='train', path=RESULT_PATH)
    plot_rewards(rewards, ma_rewards, tag="train", algo=cfg.algo, path=RESULT_PATH)
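To get a feel for the exploration schedule produced by the config above (epsilon_start=0.95, epsilon_end=0.01, epsilon_decay=200), here is a small standalone sketch (not part of the commit) of the decay formula from `DQN.choose_action`:

```python
import math

epsilon_start, epsilon_end, epsilon_decay = 0.95, 0.01, 200

def epsilon_at(sample_count):
    # Same exponential schedule as in DQN.choose_action: starts near
    # epsilon_start and decays towards epsilon_end as samples accumulate.
    return epsilon_end + (epsilon_start - epsilon_end) * \
        math.exp(-1. * sample_count / epsilon_decay)

for step in [0, 200, 1000, 5000]:
    print(step, round(epsilon_at(step), 3))
# 0 0.95
# 200 0.356
# 1000 0.016
# 5000 0.01
```

Since the counter is incremented once per environment step rather than per episode, and CartPole episodes here can last up to 200 steps, the policy is already close to fully greedy after only a few episodes.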
The remaining files in this commit are binary or could not be displayed.