Skip to content

Commit ceef522

Browse files
Kaixhinapaszke
authored andcommitted
Expanded DQN tutorial (pytorch#4)
1 parent a2071de commit ceef522

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

Reinforcement (Q-)Learning with PyTorch.ipynb

+33-23
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@
66
"source": [
77
"# PyTorch DQN tutorial\n",
88
"\n",
9-
"This tutorial shows how to use pytorch to train an DQN agent on a CartPole-v0 task from Open AI gym.\n",
9+
"This tutorial shows how to use PyTorch to train a DQN agent on the CartPole-v0 task from the [OpenAI Gym](https://gym.openai.com/).\n",
1010
"\n",
1111
"### Task\n",
1212
"\n",
13-
"The agent has to decide to move the cart left or right, so that the pole attached to it stays upright. You can find an official board with various algorithms and visualizations [at the AI Gym website](https://gym.openai.com/envs/CartPole-v0).\n",
13+
"The agent has to decide between two actions - moving the cart left or right - so that the pole attached to it stays upright. You can find an official leaderboard with various algorithms and visualizations at the [Gym website](https://gym.openai.com/envs/CartPole-v0).\n",
1414
"\n",
1515
"![cartpole](images/cartpole.gif)\n",
1616
"\n",
17-
"This task is designed so that the input are 4 real values representing the environment state (accelerations, etc.). However this seems a bit boring, so we'll use a screen patch centered on the cart as an input. Because of this, our results aren't directly comparabe to the ones from an official leaderboard - our task is harder.\n",
17+
"As the agent observes the current state of the environment and chooses an action, the environment *transitions* to a new state, and also returns a reward that indicates the consequences of the action. In this task, the environment terminates if the pole falls over too far.\n",
1818
"\n",
19-
"This unfortunately slows down the training, because we have to render all the frames."
19+
"The CartPole task is designed so that the inputs to the agent are 4 real values representing the environment state (position, velocity, etc.). However, neural networks can solve the task purely by looking at the scene, so we'll use a patch of the screen centered on the cart as an input. Because of this, our results aren't directly comparable to the ones from the official leaderboard - our task is much harder. Unfortunately this does slow down the training, because we have to render all the frames.\n",
20+
"\n",
21+
"Strictly speaking, we will present the state as the difference between the current screen patch and the previous one. This will allow the agent to take the velocity of the pole into account from one image."
2022
]
2123
},
2224
{
@@ -25,12 +27,12 @@
2527
"source": [
2628
"### Packages\n",
2729
"\n",
28-
"First, let's import needed packages. From PyTorch, we'll use:\n",
30+
"First, let's import needed packages. Firstly, we need [`gym`](https://gym.openai.com/docs) for the environment. We'll also use the following from PyTorch:\n",
2931
"\n",
30-
"* neural network package (`torch.nn`)\n",
31-
"* optimization package (`torch.optim`)\n",
32-
"* automatic differentiation package (`torch.autograd`)\n",
33-
"* package with utilities for vision tasks (`torch_vision`)."
32+
"* neural networks (`torch.nn`)\n",
33+
"* optimization (`torch.optim`)\n",
34+
"* automatic differentiation (`torch.autograd`)\n",
35+
"* utilities for vision tasks (`torchvision` - [a separate package](https://github.com/pytorch/vision))."
3436
]
3537
},
3638
{
@@ -76,9 +78,9 @@
7678
"cell_type": "markdown",
7779
"metadata": {},
7880
"source": [
79-
"### Replay memory\n",
81+
"### Replay Memory\n",
8082
"\n",
81-
"We'll be using experience replay for training our DQN. It allows us to reuse the data we observed earlier and sample from it randomly, so the transitions that build up a batch are decorelated. It has been shown that this greately stabilizes and improves the DQN training procedure.\n",
83+
"We'll be using experience replay memory for training our DQN. It stores the transitions that the agent observes, allowing us to reuse this data later. By sampling from it randomly, the transitions that build up a batch are decorrelated. It has been shown that this greatly stabilizes and improves the DQN training procedure.\n",
8284
"\n",
8385
"For this, we're going to need two classses:\n",
8486
"\n",
@@ -125,26 +127,34 @@
125127
"\n",
126128
"### DQN algorithm\n",
127129
"\n",
128-
"Our world is deterministic, so all equations presented here are also assuming determinism of the process.\n",
130+
"Our environment is deterministic, so all equations presented here are also formulated deterministically for the sake of simplicity. In the reinforcement learning literature, they would also contain expectations over stochastic transitions in the environment.\n",
129131
"\n",
130-
"Our aim will be to train a policy that tries to maximize the discounted reward $R_{t_0} = \\sum_{t=t_0}^{\\infty} r_t \\gamma^{t - t_0}$. $\\gamma$ should be a constant between $0$ and $1$ that ensures the sum converges, and makes the rewards from the uncertain far future be less important for our agent than the ones it can be fairly confident about.\n",
132+
"Our aim will be to train a policy that tries to maximize the discounted, cumulative reward $R_{t_0} = \\sum_{t=t_0}^{\\infty} \\gamma^{t - t_0} r_t$, where $R_{t_0}$ is also known as the *return*. The discount, $\\gamma$, should be a constant between $0$ and $1$ that ensures the sum converges. It makes rewards from the uncertain far future less important for our agent than the ones in the near future that it can be fairly confident about.\n",
131133
"\n",
132-
"The main idea behind Q-learning is that if we had a function $Q^*: State \\times Action \\rightarrow \\mathbb{R}$, that could tell us what would our discounted reward be, if we were to take an action in a given state, we could easily construct a policy that miximizes our rewards:\n",
134+
"The main idea behind Q-learning is that if we had a function $Q^*: State \\times Action \\rightarrow \\mathbb{R}$, that could tell us what our return would be, if we were to take an action in a given state, then we could easily construct a policy that maximizes our rewards:\n",
133135
"\n",
134-
"$$\\pi^*(s) = \\mathrm{argmax}_a \\ Q^*(s, a)$$\n",
136+
"$$\\pi^*(s) = \\arg\\!\\max_a \\ Q^*(s, a)$$\n",
135137
"\n",
136-
"However, we don't know everything about the world, so we don't have access to $Q^*$, but since neural networks are universal function approximators, we can simply create one and train it to resemble the $Q^*$.\n",
138+
"However, we don't know everything about the world, so we don't have access to $Q^*$. But, since neural networks are universal function approximators, we can simply create one and train it to resemble $Q^*$.\n",
137139
"\n",
138-
"For our training update rule, we'll use a fact that every $Q$ function for some policy obeys the Bellman equation.\n",
140+
"For our training update rule, we'll use a fact that every $Q$ function for some policy obeys the Bellman equation:\n",
139141
"\n",
140142
"$$Q^{\\pi}(s, a) = r + \\gamma Q^{\\pi}(s', \\pi(s'))$$\n",
141143
"\n",
142-
"Our loss will be a mean squared error between the two sides of the equality (where $B$ is a batch of transitions):\n",
143-
"$$L = \\frac{1}{|B|}\\sum_{(s, a, s', r) \\ \\in \\ B} (Q(s, a) - (r + \\gamma \\max_a Q(s', a)))^2$$\n",
144+
"The difference between the two sides of the equality is known as the temporal difference error, $\\delta$:\n",
145+
"\n",
146+
"$$\\delta = Q(s, a) - (r + \\gamma \\max_a Q(s', a))$$\n",
147+
"\n",
148+
"To minimise this error, we will use the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss). The Huber loss acts like the mean squared error when the error is small, but like the mean absolute error when the error is large - this makes it more robust to outliers when the estimates of $Q$ are very noisy. We calculate this over a batch of transitions, $B$, sampled from the replay memory:\n",
149+
"\n",
150+
"$$\\mathcal{L} = \\frac{1}{|B|}\\sum_{(s, a, s', r) \\ \\in \\ B} \\mathcal{L}(\\delta) \\quad \\text{where} \\quad \\mathcal{L}(\\delta) = \\begin{cases}\n",
151+
" \\frac{1}{2}{\\delta^2} & \\text{for } |\\delta| \\le 1, \\\\\n",
152+
" |\\delta| - \\frac{1}{2} & \\text{otherwise.}\n",
153+
"\\end{cases}$$\n",
144154
"\n",
145155
"### Q-network\n",
146156
"\n",
147-
"Our model will be a CNN that takes in a difference between the current screen patch, and the previous one. This will allow it to take the velocity of the pole into account. It has two outputs representing $Q(s, \\mathrm{left})$ and $Q(s, \\mathrm{right})$ (where $s$ is the input to the network)."
157+
"Our model will be a convolutional neural network that takes in the difference between the current and previous screen patches. It has two outputs, representing $Q(s, \\mathrm{left})$ and $Q(s, \\mathrm{right})$ (where $s$ is the input to the network). In effect, the network is trying to predict the *quality* of taking each action given the current input."
148158
]
149159
},
150160
{
@@ -179,7 +189,7 @@
179189
"source": [
180190
"### Input extraction\n",
181191
"\n",
182-
"The code below are utilities for extracting and processing rendered images from the env. It uses the `torch_vision` package, that makes it easy to compose image transforms. Once you run the cell it will display an example patch it extracted."
192+
"The code below are utilities for extracting and processing rendered images from the environment. It uses the `torchvision` package, which makes it easy to compose image transforms. Once you run the cell it will display an example patch that it extracted."
183193
]
184194
},
185195
{
@@ -303,9 +313,9 @@
303313
"\n",
304314
"Finally, the code for training our model.\n",
305315
"\n",
306-
"At the top you can find an `optimize_model` function that performs a single step of the optimization. It first samples a batch, concatenates all the tensors into a single one, computes $Q(s_t, a_t)$ and $V(s_{t+1}) = \\max_a Q(s_{t+1}, a)$, and combines them into our loss. There's some complication because of the final states, for which $V(s) = 0$.\n",
316+
"At the top you can find an `optimize_model` function that performs a single step of the optimization. It first samples a batch, concatenates all the tensors into a single one, computes $Q(s_t, a_t)$ and $V(s_{t+1}) = \\max_a Q(s_{t+1}, a)$, and combines them into our loss. By defition we set $V(s) = 0$ if $s$ is a terminal state.\n",
307317
"\n",
308-
"Below, you can find the main training loop. At the beginning we reset the env and initialize the `state` variable. Then, we sample an action, execute it, observe the next screen and the reward (always 1), and optimize our model once. When the episode ends (our model fails), we restart the loop."
318+
"Below, you can find the main training loop. At the beginning we reset the environment and initialize the `state` variable. Then, we sample an action, execute it, observe the next screen and the reward (always 1), and optimize our model once. When the episode ends (our model fails), we restart the loop."
309319
]
310320
},
311321
{

0 commit comments

Comments
 (0)