Add dqn_atari.py documentation (vwxyzjn#124)
* Refactor DQN documentation

* Fix typo

* Add docs

* Add current notes

* Update docs

* Add reproduce scripts

* use new api

* Fix experiment script

* Update DQN docs

* Update documentation

* Update docs

* Update DQN documentation

* Add explanation of metrics
vwxyzjn authored Mar 24, 2022
1 parent cfed3dd commit 6febe20
Showing 7 changed files with 139 additions and 14 deletions.
14 changes: 14 additions & 0 deletions benchmark/dqn/atari.sh
@@ -0,0 +1,14 @@
# PongNoFrameskip-v4
poetry run python cleanrl/dqn_atari.py --env-id PongNoFrameskip-v4 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn_atari.py --env-id PongNoFrameskip-v4 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn_atari.py --env-id PongNoFrameskip-v4 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# BeamRiderNoFrameskip-v4
poetry run python cleanrl/dqn_atari.py --env-id BeamRiderNoFrameskip-v4 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn_atari.py --env-id BeamRiderNoFrameskip-v4 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn_atari.py --env-id BeamRiderNoFrameskip-v4 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark

# BreakoutNoFrameskip-v4
poetry run python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4 --track --capture-video --seed 1 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4 --track --capture-video --seed 2 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
poetry run python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4 --track --capture-video --seed 3 --wandb-project-name cleanrl --wandb-entity openrlbenchmark
2 changes: 1 addition & 1 deletion cleanrl/dqn_atari.py
@@ -60,7 +60,7 @@ def parse_args():
help="the batch size of sample from the reply memory")
parser.add_argument("--start-e", type=float, default=1,
help="the starting epsilon for exploration")
parser.add_argument("--end-e", type=float, default=0.02,
parser.add_argument("--end-e", type=float, default=0.01,
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.10,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
133 changes: 122 additions & 11 deletions docs/rl-algorithms/dqn.md
@@ -1,22 +1,133 @@
# Deep Q-Learning (DQN)

## Overview

As an extension of Q-learning, DQN's main technical contributions are the use of a replay buffer and a target network, both of which help improve the stability of the algorithm.
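
To make these two ingredients concrete, here is a minimal, hypothetical sketch (not CleanRL's actual code; `q_network`, `target_network`, and `replay_buffer` are placeholder names): transitions are stored in a buffer and sampled uniformly to decorrelate updates, and a periodically synced copy of the Q-network supplies stable bootstrap targets.

```python
# Minimal sketch of DQN's two stabilizing ingredients (illustrative only).
import random
from collections import deque

import torch.nn as nn

q_network = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
target_network = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
target_network.load_state_dict(q_network.state_dict())  # start in sync

replay_buffer = deque(maxlen=100_000)  # stores (s, a, r, s', done) transitions


def store(transition):
    replay_buffer.append(transition)


def sample(batch_size=32):
    # uniform sampling breaks the temporal correlation of consecutive steps
    return random.sample(list(replay_buffer), batch_size)


def maybe_sync_target(step, target_network_frequency=1000):
    # the frozen copy provides stable bootstrap targets between syncs;
    # it is refreshed only every `target_network_frequency` steps
    if step % target_network_frequency == 0:
        target_network.load_state_dict(q_network.state_dict())
```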


Original papers:

* [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602)
* [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236)

Our single-file implementations of DQN:
## Implemented Variants


| Variants Implemented | Description |
| ----------- | ----------- |
| :material-github: [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), :material-file-document: [docs](/rl-algorithms/dqn/#dqnpy) | For classic control tasks like `CartPole-v1`. |
| :material-github: [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), :material-file-document: [docs](/rl-algorithms/dqn/#dqn_ataripy) | For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques. |

Below are our single-file implementations of DQN:

## `dqn.py`

The [dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py) has the following features:

* Works with the `Box` observation space of low-level features
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`

### Implementation details

[dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py) includes the 11 core implementation details:



## `dqn_atari.py`

The [dqn_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py) has the following features:

* For playing Atari games. It uses convolutional layers and common Atari pre-processing techniques (see the sketch after this list).
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space
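
For reference, the sketch below shows the kind of Atari pre-processing stack such scripts typically build, assembled from `stable_baselines3`'s Atari wrappers and `gym`'s observation wrappers. It is a hedged illustration of the general recipe, not necessarily the exact wrappers, order, or arguments used by `dqn_atari.py`.

```python
# Illustrative Atari pre-processing pipeline (assumed wrapper choices).
import gym
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)


def make_atari_env(env_id: str = "BreakoutNoFrameskip-v4") -> gym.Env:
    env = gym.make(env_id)
    env = NoopResetEnv(env, noop_max=30)   # random number of no-ops at reset
    env = MaxAndSkipEnv(env, skip=4)       # frame skipping with max-pooling
    env = EpisodicLifeEnv(env)             # treat a life loss as episode end
    env = FireResetEnv(env)                # press FIRE at reset (for games that need it)
    env = ClipRewardEnv(env)               # clip rewards to {-1, 0, +1}
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 4)  # stack the last 4 frames
    return env
```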

### Usage

```bash
poetry install -E atari
python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/dqn_atari.py --env-id PongNoFrameskip-v4
```


### Explanation of the logged metrics

Running `python cleanrl/dqn_atari.py` will automatically record various metrics, such as episodic returns and the TD loss, in TensorBoard. Below is the documentation for these metrics:

* `charts/episodic_return`: episodic return of the game
* `charts/SPS`: number of steps per second
* `losses/td_loss`: the mean squared error (MSE) between the Q values at timestep $t$ and the Bellman update target estimated using the reward $r_t$ and the Q values at timestep $t+1$, thus minimizing the *one-step* temporal difference error. Formally, it can be expressed by the equation below (a code sketch follows this list).
$$
J(\theta^{Q}) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \big[ (Q(s, a) - y)^2 \big],
$$
where $y = r + \gamma \max_{a'} Q^{'}(s', a')$ is the Bellman update target computed with the target network $Q^{'}$, and $\mathcal{D}$ is the replay buffer.
* `losses/q_values`: the average Q value of the minibatch sampled from the replay buffer; useful for gauging whether the Q-network under- or overestimates the true values.
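
As a rough illustration (the names and shapes below are placeholders, not necessarily those used in `dqn_atari.py`), the two loss metrics can be computed from a sampled minibatch like this:

```python
# Illustrative computation of `losses/td_loss` and `losses/q_values`
# for one replay-buffer minibatch (hypothetical names and shapes).
import types

import torch
import torch.nn as nn
import torch.nn.functional as F


def dqn_losses(q_network, target_network, batch, gamma=0.99):
    """Return (td_loss, mean_q_value) for one sampled minibatch."""
    with torch.no_grad():
        # Bellman target: r + gamma * max_a' Q'(s', a'), zeroed at terminal states
        target_max = target_network(batch.next_observations).max(dim=1).values
        td_target = batch.rewards + gamma * target_max * (1.0 - batch.dones)
    # Q(s, a) for the actions that were actually taken
    old_val = q_network(batch.observations).gather(1, batch.actions).squeeze(1)
    td_loss = F.mse_loss(old_val, td_target)  # logged as losses/td_loss
    return td_loss, old_val.mean()            # mean Q -> losses/q_values


# Smoke test with random data: 4-dimensional observations, 2 discrete actions.
batch = types.SimpleNamespace(
    observations=torch.randn(32, 4),
    actions=torch.randint(0, 2, (32, 1)),
    rewards=torch.randn(32),
    next_observations=torch.randn(32, 4),
    dones=torch.zeros(32),
)
td_loss, mean_q = dqn_losses(nn.Linear(4, 2), nn.Linear(4, 2), batch)
```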


### Implementation details

[dqn_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py) is based on (Mnih et al., 2015)[^1] but presents a few implementation differences:

1. `dqn_atari.py` uses slightly different hyperparameters. Specifically,
- `dqn_atari.py` uses the more popular Adam optimizer with `--learning-rate=1e-4` as follows:
```python
optim.Adam(q_network.parameters(), lr=1e-4)
```
whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses the RMSProp optimizer with `--learning-rate=2.5e-4`, gradient momentum `0.95`, squared gradient momentum `0.95`, and min squared gradient `0.01` as follows:
```python
optim.RMSprop(
    q_network.parameters(),
    lr=2.5e-4,
    momentum=0.95,
    # ... PyTorch's RMSprop does not directly support
    # squared gradient momentum and min squared gradient,
    # so we are not sure what to put here.
)
```
- `dqn_atari.py` uses `--learning-starts=80000` whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses `--learning-starts=50000`.
- `dqn_atari.py` uses `--target-network-frequency=1000` whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses `--target-network-frequency=10000`.
- `dqn_atari.py` uses `--total-timesteps=10000000` (i.e., 10M timesteps = 40M frames because of frame-skipping) whereas (Mnih et al., 2015)[^1] uses `--total-timesteps=50000000` (i.e., 50M timesteps = 200M frames) (See "Training details" under "METHODS" on page 6 and the related source code [run_gpu#L32](https://github.com/deepmind/dqn/blob/9d9b1d13a2b491d6ebd4d046740c511c662bbe0f/run_gpu#L32), [dqn/train_agent.lua#L81-L82](https://github.com/deepmind/dqn/blob/9d9b1d13a2b491d6ebd4d046740c511c662bbe0f/dqn/train_agent.lua#L81-L82), and [dqn/train_agent.lua#L165-L169](https://github.com/deepmind/dqn/blob/9d9b1d13a2b491d6ebd4d046740c511c662bbe0f/dqn/train_agent.lua#L165-L169)).
- `dqn_atari.py` uses `--end-e=0.01` (the final exploration epsilon) whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses `--end-e=0.1`.
- `dqn_atari.py` uses `--exploration-fraction=0.1` whereas (Mnih et al., 2015)[^1] (Extended Data Table 1) uses `--exploration-fraction=0.02` (which corresponds to 250,000 steps, or 1M frames, over which epsilon is annealed down to `--end-e=0.1`); a linear-schedule sketch follows this list.
- `dqn_atari.py` treats termination and truncation the same way due to the gym interface[^2] whereas (Mnih et al., 2015)[^1] correctly handles truncation.
1. `dqn_atari.py` uses a self-contained evaluation scheme: `dqn_atari.py` reports the episodic returns obtained throughout training, whereas (Mnih et al., 2015)[^1] trains with `--end-e=0.1` but reports episodic returns using a separate evaluation process with `--end-e=0.01` (See "Evaluation procedure" under "METHODS" on page 6).
1. `dqn_atari.py` rescales the gradient so that the gradient norm does not exceed `0.5`, as done in PPO (:material-github: [ppo2/model.py#L102-L108](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L102-L108)); see the sketch after this list.
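
For illustration, the sketch below shows what the linear epsilon schedule implied by `--start-e`/`--end-e`/`--exploration-fraction` and the gradient-norm rescaling step look like. The function and variable names are assumptions for this example, not a verbatim excerpt from `dqn_atari.py`.

```python
# Illustrative sketches of the exploration schedule and gradient clipping
# discussed above (assumed names, not a verbatim excerpt).
import torch
import torch.nn as nn


def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    """Linearly anneal epsilon from start_e to end_e over `duration` steps."""
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


# With --total-timesteps=10000000 and --exploration-fraction=0.1,
# epsilon reaches end_e after 1M steps; halfway there it is 0.505.
epsilon = linear_schedule(1.0, 0.01, duration=1_000_000, t=500_000)

# Gradient-norm rescaling between backward() and the optimizer step:
# gradients (not parameters) are rescaled so their global norm is at most 0.5.
q_network = nn.Linear(4, 2)
loss = q_network(torch.randn(8, 4)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(q_network.parameters(), max_norm=0.5)
```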


### Experiment results

PR :material-github: [vwxyzjn/cleanrl#124](https://github.com/vwxyzjn/cleanrl/pull/124) tracks our effort to conduct experiments, and the reproduction instructions can be found at :material-github: [vwxyzjn/cleanrl/benchmark/dqn](https://github.com/vwxyzjn/cleanrl/tree/master/benchmark/dqn).

Below are the average episodic returns for `dqn_atari.py`.


| Environment | `dqn_atari.py` 10M steps | (Mnih et al., 2015)[^1] 50M steps | (Hessel et al., 2017, Figure 5)[^3] |
| ----------- | ----------- | ----------- | ----------- |
| BreakoutNoFrameskip-v4 | 337.64 ± 69.47 | 401.2 ± 26.9 | ~230 at 10M steps, ~300 at 50M steps |
| PongNoFrameskip-v4 | 20.293 ± 0.37 | 18.9 ± 1.3 | ~20 at 10M steps, ~20 at 50M steps |
| BeamRiderNoFrameskip-v4 | 6207.41 ± 1019.96 | 6846 ± 1619 | ~6000 at 10M steps, ~7000 at 50M steps |


Note that we save computation time by reducing the number of timesteps from 50M to 10M, yet our `dqn_atari.py` achieves scores comparable to or higher than those of (Mnih et al., 2015)[^1] within 10M steps.


Learning curves:

<div class="grid-container">
<img src="../dqn/BeamRiderNoFrameskip-v4.png">

<img src="../dqn/BreakoutNoFrameskip-v4.png">

<img src="../dqn/PongNoFrameskip-v4.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Atari-CleanRL-s-DQN--VmlldzoxNjk3NjYx" style="width:100%; height:500px" title="CleanRL DQN Tracked Experiments"></iframe>


* [dqn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py)
* Works with the `Box` observation space of low-level features
* Works with the `Discerete` action space
* Works with envs like `CartPole-v1`
* [dqn_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py)
* For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques.
* Works with the Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discerete` action space
[^1]: Mnih, V., Kavukcuoglu, K., Silver, D. et al. Human-level control through deep reinforcement learning. Nature 518, 529–533 (2015). https://doi.org/10.1038/nature14236
[^2]: \[Proposal\] Formal API handling of truncation vs termination. https://github.com/openai/gym/issues/2510
[^3]: Hessel, M., Modayil, J., Hasselt, H.V., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M.G., & Silver, D. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.
Binary file added docs/rl-algorithms/dqn/BeamRiderNoFrameskip-v4.png
Binary file added docs/rl-algorithms/dqn/BreakoutNoFrameskip-v4.png
Binary file added docs/rl-algorithms/dqn/PongNoFrameskip-v4.png
4 changes: 2 additions & 2 deletions docs/rl-algorithms/ppo.md
@@ -34,7 +34,7 @@ Below are our single-file implementations of PPO:
The [ppo.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py) has the following features:

* Works with the `Box` observation space of low-level features
* Works with the `Discerete` action space
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`

### Usage
@@ -102,7 +102,7 @@ The [ppo_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_at

* For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques.
* Works with the Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discerete` action space
* Works with the `Discrete` action space
* Includes the 9 Atari-specific implementation details as shown in the following video tutorial
[![PPO2](ppo/ppo-2-title.png)](https://youtu.be/05RMTj-2K_Y)

