
Commit c67cdef

isk03276 and jiseongHAN authored
Add GAIL algorithm (#315)
* [IBR-2091] Add GAIL algorithm
* [IBR-2068] Modify standard deviation of Gaussian action in PPO
* [IBR-2091] Improve GAIL algorithm
* [IBR-2068] Add PPO algorithm for discrete actions
* [IBR-2068] Add shared backbone for actor-critic
* [IBR-2068] Fix GPU OOM bug
* [IBR-2068] Tune hyper-parameters for PPO
* [IBR-2068] Modify multi env
* [IBR-2091] Modify input size of discriminator network
* [IBR-2068] Modify learner for shared actor-critic
* [IBR-2091] Add forward_backbone and forward_head functions
* [IBR-2091] Change threshold for determining discriminator accuracy
* [IBR-2068] Roll back PPO config
* [IBR-2068] Add PPO with discrete actions
* [IBR-2068] Remove retain_graph option
* [IBR-2097] Remove retain_graph option
* [IBR-2091] Add discriminator class
* [IBR-2091] Modify action embedder config
* [IBR-2091] Modify/add comments
* [IBR-2091] Fix pylint warnings
* [IBR-2091] Convert action type to numpy array in select_action function
* [IBR-2069] Modify hidden activation function
* [IBR-2097] Modify README
* [IBR-2097] Modify README
* [IBR-2069] Modify README file
* [IBR-2097] Modify README
* Update version 1.1.0 to 1.2.0
* Update ray to 1.3.0

Co-authored-by: Jiseong Han <wltjd802@gmail.com>
1 parent a607f07 commit c67cdef

File tree

14 files changed: +756, -6 lines changed


README.md

Lines changed: 10 additions & 0 deletions
@@ -72,6 +72,7 @@ This project follows the [all-contributors](https://github.com/all-contributors/
 10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent)
 11. [Distributed Prioritized Experience Replay (Ape-X)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/common/apex)
 12. [Policy Distillation](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/distillation)
+13. [Generative Adversarial Imitation Learning (GAIL)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/gail)

 ## Performance

@@ -139,6 +140,14 @@ See <a href="https://app.wandb.ai/medipixel_rl/LunarLanderContinuous-v2/reports/
 </p>
 </details>

+<details><summary><b>LunarLanderContinuous-v2: PPO, SAC, GAIL</b></summary>
+<p><br>
+See <a href="https://wandb.ai/chaehyeuk-lee/LunarLanderContinuous-v2?workspace=user-chaehyeuk-lee">W&B log</a> for more details. (The performance is measured at commit <a href="https://github.com/medipixel/rl_algorithms/commit/922222b2e249f1f14bdf1a28c9f0f00752e49907">9e897ad</a>.)
+
+![lunarlandercontinuous-v2_gail](https://user-images.githubusercontent.com/23740495/130401442-8b668975-8760-4a79-b757-1c1e9a9c4e47.png)
+</p>
+</details>
+
 #### Reacher-v2

 We reproduced the performance of **DDPG**, **TD3**, and **SAC** on Reacher-v2 (Mujoco); they reach scores around -3.5 to -4.5.
@@ -313,3 +322,4 @@ To cite this repository in publications:
 19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations, 2019](https://openreview.net/forum?id=r1lyTjAqYX)
 20. [Horgan et al., "Distributed Prioritized Experience Replay." in International Conference on Learning Representations, 2018](https://arxiv.org/pdf/1803.00933.pdf)
 21. [Simonyan et al., "Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps", 2013](https://arxiv.org/pdf/1312.6034.pdf)
+22. [Ho et al., "Generative Adversarial Imitation Learning", 2016](https://arxiv.org/abs/1606.03476)
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
type: "GAILPPOAgent"
hyper_params:
  gamma: 0.99
  tau: 0.95
  batch_size: 128
  max_epsilon: 0.2
  min_epsilon: 0.2
  epsilon_decay_period: 1500
  w_value: 1.0
  w_entropy: 0.001
  gradient_clip_ac: 0.5
  gradient_clip_cr: 1.0
  epoch: 10
  rollout_len: 1024
  n_workers: 4
  use_clipped_value_loss: False
  standardize_advantage: True
  gail_reward_weight: 1.0
  demo_path: "data/lunarlander_continuous_demo.pkl"

learner_cfg:
  type: "GAILPPOLearner"
  backbone:
    actor:
    critic:
    discriminator:
    shared_actor_critic:
  head:
    actor:
      type: "GaussianDist"
      configs:
        hidden_sizes: [256, 256]
        output_activation: "identity"
        fixed_logstd: True
    critic:
      type: "MLP"
      configs:
        hidden_sizes: [256, 256]
        output_size: 1
        output_activation: "identity"
    discriminator:
      type: "MLP"
      configs:
        hidden_sizes: [256, 256]
        output_size: 1
        output_activation: "identity"
    aciton_embedder:
      type: "MLP"
      configs:
        hidden_sizes: []
        output_size: 16
        output_activation: "identity"

  optim_cfg:
    lr_actor: 0.0003
    lr_critic: 0.001
    lr_discriminator: 0.0003
    weight_decay: 0.0
    discriminator_acc_threshold: 0.8
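The discriminator head above is a plain MLP with an identity output, paired with a small MLP action embedder. As a rough, hypothetical sketch (not code from this repository), the snippet below shows one common way such a discriminator can be turned into an imitation reward scaled by `gail_reward_weight`; the `Discriminator` class, the `gail_reward` helper, and the `action_emb` input (standing in for the embedded action) are illustrative assumptions.

```python
# Hypothetical GAIL-style imitation reward; layer sizes mirror the config above.
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """MLP scoring (state, embedded action) pairs; hidden_sizes: [256, 256]."""

    def __init__(self, state_dim: int, action_embed_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # output_activation: "identity" -> raw logit
        )

    def forward(self, state: torch.Tensor, action_emb: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([state, action_emb], dim=-1))


def gail_reward(
    discriminator: Discriminator,
    state: torch.Tensor,
    action_emb: torch.Tensor,
    reward_weight: float = 1.0,  # hyper_params.gail_reward_weight
) -> torch.Tensor:
    """Reward the policy for fooling the discriminator: -log(1 - D(s, a))."""
    with torch.no_grad():
        prob_expert = torch.sigmoid(discriminator(state, action_emb))
    return reward_weight * -torch.log(1.0 - prob_expert + 1e-8)
```

In this sketch the sigmoid of the raw logit is read as the probability that a (state, action) pair came from the expert demonstrations, and the PPO rollout reward is augmented (or replaced) by `-log(1 - D)` so the policy is pushed toward expert-like behavior.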

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ tqdm

 # for distributed learning
 redis==3.5.3 # for ray
-ray==1.2.0
+ray==1.3.0
 pyzmq==20.0.0
 pyarrow==3.0.0

rl_algorithms/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,8 @@
 from .fd.dqn_learner import DQfDLearner
 from .fd.sac_agent import SACfDAgent
 from .fd.sac_learner import SACfDLearner
+from .gail.agent import GAILPPOAgent
+from .gail.learner import GAILPPOLearner
 from .ppo.agent import PPOAgent
 from .ppo.learner import PPOLearner
 from .recurrent.dqn_agent import R2D1Agent
@@ -45,6 +47,7 @@
     "PPOAgent",
     "SACAgent",
     "TD3Agent",
+    "GAILPPOAgent",
     "A2CLearner",
     "BCDDPGLearner",
     "BCSACLearner",
@@ -56,6 +59,7 @@
     "PPOLearner",
     "SACLearner",
     "TD3Learner",
+    "GAILPPOLearner",
     "R2D1Learner",
     "LunarLanderContinuousHER",
     "ReacherHER",
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
1+
# -*- coding: utf-8 -*-
2+
"""Demo buffer for GAIL algorithm."""
3+
4+
import pickle
5+
from typing import List, Tuple
6+
7+
import numpy as np
8+
import torch
9+
10+
from rl_algorithms.common.abstract.buffer import BaseBuffer
11+
12+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13+
14+
15+
class GAILBuffer(BaseBuffer):
16+
"""Buffer to store expert states and actions.
17+
18+
Attributes:
19+
obs_buf (np.ndarray): observations
20+
acts_buf (np.ndarray): actions
21+
"""
22+
23+
def __init__(self, dataset_path: str):
24+
"""Initialize a Buffer.
25+
26+
Args:
27+
dataset_path (str): path of the demo dataset
28+
"""
29+
30+
self.obs_buf: np.ndarray = None
31+
self.acts_buf: np.ndarray = None
32+
33+
self.load_demo(dataset_path)
34+
35+
def load_demo(self, dataset_path: str):
36+
"""load demo data."""
37+
with open(dataset_path, "rb") as f:
38+
demo = list(pickle.load(f))
39+
demo = np.array(demo)
40+
self.obs_buf = np.array(list(map(np.array, demo[:, 0])))
41+
self.acts_buf = np.array(list(map(np.array, demo[:, 1])))
42+
43+
def add(self):
44+
pass
45+
46+
def sample(self, batch_size, indices: List[int] = None) -> Tuple[np.ndarray, ...]:
47+
"""Randomly sample a batch of experiences from memory."""
48+
assert 0 < batch_size < len(self)
49+
50+
if indices is None:
51+
indices = np.random.choice(len(self), size=batch_size)
52+
53+
states = self.obs_buf[indices]
54+
actions = self.acts_buf[indices]
55+
56+
return torch.Tensor(states).to(device), torch.Tensor(actions).to(device)
57+
58+
def __len__(self) -> int:
59+
"""Return the current size of internal memory."""
60+
return len(self.obs_buf)
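A minimal usage sketch of the buffer above, assuming a demo pickle exists at the path used in the config (a list whose entries store an observation first and an action second, as `load_demo` expects):

```python
# Hypothetical usage; the path comes from the config's `demo_path`.
demo_buffer = GAILBuffer("data/lunarlander_continuous_demo.pkl")

# Sample a batch of expert (state, action) pairs as tensors on `device`.
expert_states, expert_actions = demo_buffer.sample(batch_size=128)
print(expert_states.shape, expert_actions.shape)
```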

rl_algorithms/gail/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
"""Empty."""
