Skip to content
This repository was archived by the owner on Oct 19, 2023. It is now read-only.

update dqn example #10

Merged
merged 1 commit into from
Sep 10, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,13 @@
"outputs": [],
"source": [
"import gym\n",
"from gym import wrappers\n",
"import random\n",
"import math\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"from collections import deque\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI Gym을 이용하여 게임환경 구축하기\n",
"\n",
"\n",
"강화학습 예제들을 보면 항상 게임과 연관되어 있습니다. 원래 우리가 궁극적으로 원하는 목표는 어디서든 적응할 수 있는 인공지능이지만, 너무 복잡한 문제이기도 하고 가상 환경을 설계하기도 어렵기 때문에 일단 게임이라는 환경을 사용해 하는 것입니다.\n",
"\n",
"대부분의 게임은 점수 혹은 목표가 있습니다. 점수가 오르거나 목표에 도달하면 일종의 리워드를 받고 원치 않은 행동을 할때는 마이너스 리워드를 주는 경우도 있습니다. 아까 비유를 들었던 달리기를 배울때의 경우를 예로 들면 총 나아간 길이 혹은 목표 도착지 도착 여부로 리워드를 주고 넘어질때 패널티를 줄 수 있을 것입니다. \n",
"\n",
"게임중에서도 가장 간단한 카트폴이라는 환경을 구축하여 강화학습을 배울 토대를 마련해보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"env = gym.make('CartPole-v1')"
"from collections import deque"
]
},
{
Expand All @@ -60,18 +33,18 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# 하이퍼파라미터\n",
"EPISODES = 50 # 에피소드 반복 횟수\n",
"EPS_START = 0.9 # e-greedy threshold 시작 값\n",
"EPS_END = 0.05 # e-greedy threshold 최종 값\n",
"EPS_DECAY = 200 # e-greedy threshold decay\n",
"GAMMA = 0.8 # \n",
"LR = 0.001 # NN optimizer learning rate\n",
"BATCH_SIZE = 64 # Q-learning batch size"
"EPISODES = 50 # 애피소드 반복횟수\n",
"EPS_START = 0.9 # 학습 시작시 에이전트가 무작위로 행동할 확률\n",
"EPS_END = 0.05 # 학습 막바지에 에이전트가 무작위로 행동할 확률\n",
"EPS_DECAY = 200 # 학습 진행시 에이전트가 무작위로 행동할 확률을 감소시키는 값\n",
"GAMMA = 0.8 # 할인계수\n",
"LR = 0.001 # 학습률\n",
"BATCH_SIZE = 64 # 배치 크기"
]
},
{
Expand All @@ -94,9 +67,15 @@
" nn.ReLU(),\n",
" nn.Linear(256, 2)\n",
" )\n",
" self.memory = deque(maxlen=10000)\n",
" self.optimizer = optim.Adam(self.model.parameters(), LR)\n",
" self.steps_done = 0\n",
" self.memory = deque(maxlen=10000)\n",
"\n",
" def memorize(self, state, action, reward, next_state):\n",
" self.memory.append((state,\n",
" action,\n",
" torch.FloatTensor([reward]),\n",
" torch.FloatTensor([next_state])))\n",
" \n",
" def act(self, state):\n",
" eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * self.steps_done / EPS_DECAY)\n",
Expand All @@ -105,12 +84,6 @@
" return self.model(state).data.max(1)[1].view(1, 1)\n",
" else:\n",
" return torch.LongTensor([[random.randrange(2)]])\n",
"\n",
" def memorize(self, state, action, reward, next_state):\n",
" self.memory.append((state,\n",
" action,\n",
" torch.FloatTensor([reward]),\n",
" torch.FloatTensor([next_state])))\n",
" \n",
" def learn(self):\n",
" \"\"\"Experience Replay\"\"\"\n",
Expand Down Expand Up @@ -140,10 +113,9 @@
"source": [
"## 학습 준비하기\n",
"\n",
"드디어 만들어둔 DQNAgent를 인스턴스화 합니다.\n",
"그리고 `gym`을 이용하여 `CartPole-v0`환경도 준비합니다.\n",
"자, 이제 `agent` 객체를 이용하여 `CartPole-v0` 환경과 상호작용을 통해 게임을 배우도록 하겠습니다.\n",
"학습 진행을 기록하기 위해 `score_history` 리스트를 이용하여 점수를 저장하겠습니다."
"`gym`을 이용하여 `CartPole-v0`환경을 준비하고 앞서 만들어둔 DQNAgent를 agent로 인스턴스화 합니다.\n",
"\n",
"자, 이제 `agent` 객체를 이용하여 `CartPole-v0` 환경과 상호작용을 통해 게임을 배우도록 하겠습니다."
]
},
{
Expand All @@ -152,8 +124,8 @@
"metadata": {},
"outputs": [],
"source": [
"agent = DQNAgent()\n",
"env = gym.make('CartPole-v0')\n",
"agent = DQNAgent()\n",
"score_history = []"
]
},
Expand All @@ -173,56 +145,56 @@
"name": "stdout",
"output_type": "stream",
"text": [
"에피소드:1 점수: 11\n",
"에피소드:2 점수: 32\n",
"에피소드:3 점수: 10\n",
"에피소드:4 점수: 36\n",
"에피소드:5 점수: 13\n",
"에피소드:6 점수: 17\n",
"에피소드:7 점수: 9\n",
"에피소드:8 점수: 13\n",
"에피소드:9 점수: 11\n",
"에피소드:10 점수: 28\n",
"에피소드:11 점수: 11\n",
"에피소드:12 점수: 12\n",
"에피소드:1 점수: 21\n",
"에피소드:2 점수: 28\n",
"에피소드:3 점수: 21\n",
"에피소드:4 점수: 51\n",
"에피소드:5 점수: 12\n",
"에피소드:6 점수: 20\n",
"에피소드:7 점수: 8\n",
"에피소드:8 점수: 9\n",
"에피소드:9 점수: 10\n",
"에피소드:10 점수: 12\n",
"에피소드:11 점수: 14\n",
"에피소드:12 점수: 11\n",
"에피소드:13 점수: 9\n",
"에피소드:14 점수: 20\n",
"에피소드:15 점수: 11\n",
"에피소드:16 점수: 11\n",
"에피소드:17 점수: 10\n",
"에피소드:18 점수: 15\n",
"에피소드:19 점수: 13\n",
"에피소드:20 점수: 11\n",
"에피소드:21 점수: 13\n",
"에피소드:22 점수: 22\n",
"에피소드:23 점수: 26\n",
"에피소드:24 점수: 59\n",
"에피소드:14 점수: 9\n",
"에피소드:15 점수: 10\n",
"에피소드:16 점수: 26\n",
"에피소드:17 점수: 11\n",
"에피소드:18 점수: 9\n",
"에피소드:19 점수: 11\n",
"에피소드:20 점수: 25\n",
"에피소드:21 점수: 12\n",
"에피소드:22 점수: 19\n",
"에피소드:23 점수: 12\n",
"에피소드:24 점수: 27\n",
"에피소드:25 점수: 30\n",
"에피소드:26 점수: 22\n",
"에피소드:27 점수: 26\n",
"에피소드:28 점수: 25\n",
"에피소드:29 점수: 57\n",
"에피소드:30 점수: 83\n",
"에피소드:31 점수: 62\n",
"에피소드:32 점수: 45\n",
"에피소드:33 점수: 62\n",
"에피소드:34 점수: 80\n",
"에피소드:35 점수: 88\n",
"에피소드:36 점수: 57\n",
"에피소드:37 점수: 52\n",
"에피소드:38 점수: 45\n",
"에피소드:39 점수: 49\n",
"에피소드:40 점수: 63\n",
"에피소드:41 점수: 61\n",
"에피소드:42 점수: 75\n",
"에피소드:43 점수: 52\n",
"에피소드:44 점수: 81\n",
"에피소드:45 점수: 98\n",
"에피소드:46 점수: 129\n",
"에피소드:47 점수: 153\n",
"에피소드:48 점수: 169\n",
"에피소드:49 점수: 120\n",
"에피소드:50 점수: 144\n"
"에피소드:26 점수: 66\n",
"에피소드:27 점수: 24\n",
"에피소드:28 점수: 45\n",
"에피소드:29 점수: 47\n",
"에피소드:30 점수: 35\n",
"에피소드:31 점수: 35\n",
"에피소드:32 점수: 40\n",
"에피소드:33 점수: 44\n",
"에피소드:34 점수: 34\n",
"에피소드:35 점수: 57\n",
"에피소드:36 점수: 52\n",
"에피소드:37 점수: 70\n",
"에피소드:38 점수: 124\n",
"에피소드:39 점수: 118\n",
"에피소드:40 점수: 33\n",
"에피소드:41 점수: 128\n",
"에피소드:42 점수: 55\n",
"에피소드:43 점수: 178\n",
"에피소드:44 점수: 88\n",
"에피소드:45 점수: 103\n",
"에피소드:46 점수: 101\n",
"에피소드:47 점수: 120\n",
"에피소드:48 점수: 140\n",
"에피소드:49 점수: 113\n",
"에피소드:50 점수: 85\n"
]
}
],
Expand Down Expand Up @@ -252,19 +224,38 @@
" break"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[21, 28, 21, 51, 12, 20, 8, 9, 10, 12, 14, 11, 9, 9, 10, 26, 11, 9, 11, 25, 12, 19, 12, 27, 30, 66, 24, 45, 47, 35, 35, 40, 44, 34, 57, 52, 70, 124, 118, 33, 128, 55, 178, 88, 103, 101, 120, 140, 113, 85]\n"
]
}
],
"source": [
"print(score_history)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"import matplotlib"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "pytorch",
"language": "python",
"name": "python3"
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -276,7 +267,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.7.3"
}
},
"nbformat": 4,
Expand Down