本项目用于学习RL基础算法,尽量做到: 注释详细,结构清晰。
代码结构主要分为以下几个脚本:
model.py
强化学习算法的基本模型,比如神经网络,actor,critic等memory.py
保存Replay Buffer,用于off-policyplot.py
利用matplotlib或seaborn绘制rewards图,包括滑动平均的reward,结果保存在result文件夹中env.py
用于构建强化学习环境,也可以重新自定义环境,比如给action加noiseagent.py
RL核心算法,比如dqn等,主要包含update和choose_action两个方法,train.py
保存用于训练和测试的函数
其中model.py
,memory.py
,plot.py
由于不同算法都会用到,所以放入common
文件夹中。
注意:新版本中将model
,memory
相关内容全部放到了agent.py
里面,plot
放到了common.utils
中。
python 3.7、pytorch 1.6.0-1.8.1、gym 0.21.0
直接运行带有train
的py文件或ipynb文件会进行训练默认的任务;
也可以运行带有task
的py文件训练不同的任务