Code updates, DSS addition, and tensorboard addition
dgriff777 committed Apr 3, 2023
1 parent eb5c9b9 commit 94666cc
Showing 10 changed files with 548 additions and 443 deletions.
36 changes: 23 additions & 13 deletions README.MD
100644 → 100755
@@ -1,20 +1,20 @@

[23]: https://github.com/dgriff777/a3c_continuous
*Update: Minor updates to the code. Added distributed step size training functionality. Added TensorBoard integration so you can log and graph training, view a graph of the model, and visualize your weight and bias distributions as they update during training.*

## NEWLY ADDED A3G A NEW GPU/CPU ARCHITECTURE OF A3C FOR SUBSTANTIALLY ACCELERATED TRAINING!!
# A3G: A GPU/CPU ARCHITECTURE OF A3C FOR SUBSTANTIALLY ACCELERATED TRAINING


# RL A3C Pytorch

![A3C LSTM playing Breakout-v0](https://github.com/dgriff777/rl_a3c_pytorch/blob/master/demo/Breakout.gif) ![A3C LSTM playing SpaceInvadersDeterministic-v3](https://github.com/dgriff777/rl_a3c_pytorch/blob/master/demo/SpaceInvaders.gif) ![A3C LSTM playing MsPacman-v0](https://github.com/dgriff777/rl_a3c_pytorch/blob/master/demo/MsPacman.gif) ![A3C LSTM\
playing BeamRider-v0](https://github.com/dgriff777/rl_a3c_pytorch/blob/master/demo/BeamRider.gif) ![A3C LSTM playing Seaquest-v0](https://github.com/dgriff777/rl_a3c_pytorch/blob/master/demo/Seaquest.gif)

# NEWLY ADDED A3G!!
New implementation of A3C that utilizes GPU for speed increase in training. Which we can call **A3G**. A3G as opposed to other versions that try to utilize GPU with A3C algorithm, with A3G each agent has its own network maintained on GPU but shared model is on CPU and agent models are quickly converted to CPU to update shared model which allows updates to be frequent and fast by utilizing Hogwild Training and make updates to shared model asynchronously and without locks. This new method greatly increase training speed and models that use to take days to train can be trained in as fast as 10minutes for some Atari games! 10-15minutes for Breakout to start to score over 400! And 10mins to solve Pong!
# A3G
Implementation of A3C that utilizes the GPU for a speed increase in training, which we call **A3G**. Unlike other versions that try to use the GPU with the A3C algorithm, in A3G each agent has its own network maintained on the GPU while the shared model stays on the CPU; agent models are quickly converted to CPU to update the shared model, which keeps updates frequent and fast by utilizing Hogwild training to update the shared model asynchronously and without locks. This new method greatly increases training speed: models that used to take days to train can now be trained in as little as 10 minutes for some Atari games! 10-15 minutes for Breakout to start scoring over 400, and about 10 minutes to solve Pong!
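A minimal sketch of that update path (illustrative only, not the repo's exact code; `local_model`, `shared_model`, and `shared_optimizer` are placeholder names):

```
def push_grads_to_shared(local_model, shared_model, shared_optimizer):
    # Copy gradients computed on the worker's GPU model onto the CPU-resident
    # shared model, then apply them Hogwild-style (asynchronously, no locks).
    for local_p, shared_p in zip(local_model.parameters(), shared_model.parameters()):
        if local_p.grad is not None:
            shared_p._grad = local_p.grad.cpu()
    shared_optimizer.step()
    # Re-sync the worker's GPU copy from the updated shared CPU weights.
    local_model.load_state_dict(shared_model.state_dict())
```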

This repository includes my implementation of reinforcement learning using Asynchronous Advantage Actor-Critic (A3C) in Pytorch, an algorithm from Google DeepMind's paper "Asynchronous Methods for Deep Reinforcement Learning."

*See [a3c_continuous][23] a newly added repo of my A3C LSTM implementation for continuous action spaces which was able to solve BipedWalkerHardcore-v2 environment (average 300+ for 100 consecutive episodes)*
*See [a3c_continuous][23], a newly added repo of my A3C LSTM implementation for continuous action spaces, which was able to solve the BipedalWalkerHardcore-v3 environment (average 300+ over 100 consecutive episodes)*


### A3C LSTM
@@ -61,23 +61,23 @@ link to the Gym environment evaluations below

- Python 2.7+
- OpenAI Gym and Universe
- Pytorch
- Pytorch (Pytorch 2.0 has a bug where it incorrectly occupies GPU memory on all GPUs in use when backward() is called in the training processes. This does not slow down training, but it does unnecessarily take up a lot of GPU memory. If this is a problem for you and you are running out of GPU memory, downgrade Pytorch; a diagnostic sketch follows this list.)
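To check whether this affects your machine, one option (not part of the repo) is to watch driver-level memory use per GPU while training runs:

```
import subprocess

# Diagnostic sketch: print per-GPU memory use as reported by the driver while
# training is running, to see whether memory is being occupied on GPUs a
# worker process should not be touching.
print(subprocess.run(
    ["nvidia-smi", "--query-gpu=index,memory.used,memory.total", "--format=csv"],
    capture_output=True, text=True, check=True,
).stdout)
```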

## Training
*When training a model it is important to limit the number of worker processes to the number of CPU cores available, as too many processes (e.g. more than one process per available CPU core) will actually be detrimental to training speed and effectiveness*

To train agent in Pong-v0 environment with 32 different worker processes:
To train an agent in the PongNoFrameskip-v4 environment with 32 different worker processes:

```
python main.py --env Pong-v0 --workers 32
python main.py --env PongNoFrameskip-v4 --workers 32
```
#A3C-GPU
*training using machine with 4 V100 GPUs and 20core CPU for PongDeterministic-v4 took 10 minutes to converge*
# A3G-Training
*Training on a machine with 4 V100 GPUs and a 20-core CPU, PongNoFrameskip-v4 took 10 minutes to converge*

To train agent in PongDeterministic-v4 environment with 32 different worker processes on 4 GPUs with new A3G:
To train an agent in the PongNoFrameskip-v4 environment with 32 different worker processes on 4 GPUs with the new A3G:

```
python main.py --env PongDeterministic-v4 --workers 32 --gpu-ids 0 1 2 3 --amsgrad True
python main.py --env PongNoFrameskip-v4 --workers 32 --gpu-ids 0 1 2 3 --amsgrad
```


@@ -88,8 +88,18 @@ Hit Ctrl C to end training session properly
## Evaluation
To run a 100-episode gym evaluation with a trained model:
```
python gym_eval.py --env Pong-v0 --num-episodes 100
python gym_eval.py --env PongNoFrameskip-v4 --num-episodes 100 --new-gym-eval
```

## Distributed Step Size training
Example of training an agent using different step sizes across training processes, drawn from a provided list of step sizes:
```
python main.py --env PongNoFrameskip-v4 --workers 18 --gpu-ids 0 1 2 --amsgrad --distributed-step-size 16 32 64 --tau 0.92
```
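How the `--distributed-step-size` list is spread over the workers is handled inside the training code; purely as an illustration, one simple scheme would be to cycle the list by worker rank, with uncovered workers keeping the default rollout length:

```
# Illustrative only; the repo's actual assignment may differ.
def worker_step_size(rank, step_sizes, default_steps=20):
    if not step_sizes:
        return default_steps
    return step_sizes[rank % len(step_sizes)]

print([worker_step_size(r, [16, 32, 64]) for r in range(18)])
# [16, 32, 64, 16, 32, 64, ...]
```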
Below is a graph from a run of the distributed step size training command above:
![PongNoFrameskip DSS Training](https://github.com/dgriff777/rl_a3c_pytorch/blob/master/demo/Pong_dss_training.png)


*Notice that BeamRiderNoFrameskip-v4 reaches scores over 50,000 in less than 2 hours of training. Compared to the gym v0 version, this shows the difficulty of those versions, but also that the time limit is a major factor in the score level*

*These training charts were done on a DGX Station using 4 GPUs and a 20-core CPU. I used 36 worker agents and a tau of 0.92, which is the lambda in the Generalized Advantage Estimation equation, to introduce more variance due to the more deterministic nature of using just a 4-frame-skip environment and a 0-30 NoOp start*
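For reference, a minimal sketch of the Generalized Advantage Estimation computation that tau feeds into (here `tau` plays the role of lambda; illustrative, not the repo's exact code):

```
def gae_advantages(rewards, values, gamma=0.99, tau=0.92):
    # values must hold len(rewards) + 1 entries (bootstrap value of the last state).
    advantages, running = [], 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * tau * running
        advantages.insert(0, running)
    return advantages
```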
Binary file added demo/Pong_dss_training.png
21 changes: 11 additions & 10 deletions environment.py
@@ -4,7 +4,7 @@
from collections import deque
from gym.spaces.box import Box
#from skimage.color import rgb2gray
from cv2 import resize
from cv2 import resize, INTER_AREA
#from skimage.transform import resize
#from scipy.misc import imresize as resize
import random
@@ -30,12 +30,13 @@ def atari_env(env_id, env_conf, args):

def process_frame(frame, conf):
frame = frame[conf["crop1"]:conf["crop2"] + 160, :160]
frame = frame.mean(2)
frame = frame.astype(np.float32)
frame *= (1.0 / 255.0)
frame = resize(frame, (80, conf["dimension2"]))
frame = resize(frame, (80, 80))
frame = np.reshape(frame, [1, 80, 80])
# frame = frame.mean(2)
# frame = frame.astype(np.float32)
# frame *= (1.0 / 255.0)
# frame = resize(frame, (80, conf["dimension2"]), interpolation=INTER_AREA)
frame = resize(frame, (80, 80), interpolation=INTER_AREA)
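# 0.2989, 0.587, 0.114 are the ITU-R BT.601 luma weights for RGB-to-grayscale conversion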
frame = (0.2989 * frame[:,:,0] + 0.587 * frame[:,:,1] + 0.114 * frame[:,:,2])
frame = np.reshape(frame, [1, 80, 80]).astype(np.float32)
return frame
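A quick sanity check of the updated `process_frame` (illustrative crop values; the real per-game values come from `config.json`):

```
import numpy as np
from environment import process_frame  # assumes this module is importable

conf = {"crop1": 34, "crop2": 34}  # illustrative values
raw = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
obs = process_frame(raw, conf)
print(obs.shape, obs.dtype)  # expected: (1, 80, 80) float32
```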


@@ -64,8 +65,8 @@ def observation(self, observation):
self.state_std = self.state_std * self.alpha + \
observation.std() * (1 - self.alpha)
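# The (1 - alpha**num_steps) terms below correct the startup bias of the exponential moving averages, as in Adam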

unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))
unbiased_mean = self.state_mean / (1 - (self.alpha**self.num_steps))
unbiased_std = self.state_std / (1 - (self.alpha**self.num_steps))

return (observation - unbiased_mean) / (unbiased_std + 1e-8)

@@ -142,7 +143,7 @@ def step(self, action):
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, self.was_real_done
return obs, reward, done, info

def reset(self, **kwargs):
"""Reset only when lives are exhausted.
192 changes: 99 additions & 93 deletions gym_eval.py
@@ -10,72 +10,79 @@
import gym
import logging
import time
#from gym.configuration import undo_logger_setup

#undo_logger_setup()
parser = argparse.ArgumentParser(description='A3C_EVAL')
gym.logger.set_level(40)

parser = argparse.ArgumentParser(description="A3C_EVAL")
parser.add_argument(
'--env',
default='Pong-v0',
metavar='ENV',
help='environment to train on (default: Pong-v0)')
"-ev",
"--env",
default="PongNoFrameskip-v4",
help="environment to train on (default: PongNoFrameskip-v4)",
)
parser.add_argument(
'--env-config',
default='config.json',
metavar='EC',
help='environment to crop and resize info (default: config.json)')
"-evc", "--env-config",
default="config.json",
help="environment to crop and resize info (default: config.json)")
parser.add_argument(
'--num-episodes',
"-ne",
"--num-episodes",
type=int,
default=100,
metavar='NE',
help='how many episodes in evaluation (default: 100)')
parser.add_argument(
'--load-model-dir',
default='trained_models/',
metavar='LMD',
help='folder to load trained models from')
help="how many episodes in evaluation (default: 100)",
)
parser.add_argument(
'--log-dir', default='logs/', metavar='LG', help='folder to save logs')
"-lmd",
"--load-model-dir",
default="trained_models/",
help="folder to load trained models from",
)
parser.add_argument("-lgd", "--log-dir", default="logs/", help="folder to save logs")
parser.add_argument(
'--render',
default=False,
metavar='R',
help='Watch game as it being played')
"-r", "--render", action="store_true", help="Watch game as it being played"
)
parser.add_argument(
'--render-freq',
"-rf",
"--render-freq",
type=int,
default=1,
metavar='RF',
help='Frequency to watch rendered game play')
help="Frequency to watch rendered game play",
)
parser.add_argument(
'--max-episode-length',
"-mel",
"--max-episode-length",
type=int,
default=10000,
metavar='M',
help='maximum length of an episode (default: 100000)')
help="maximum length of an episode (default: 100000)",
)
parser.add_argument(
"-nge",
"--new-gym-eval",
action="store_true",
help="Create a gym evaluation for upload",
)
parser.add_argument(
"-s", "--seed", type=int, default=1, help="random seed (default: 1)"
)
parser.add_argument(
'--gpu-id',
"-gid",
"--gpu-id",
type=int,
default=-1,
help='GPU to use [-1 CPU only] (default: -1)')
help="GPU to use [-1 CPU only] (default: -1)",
)
parser.add_argument(
'--skip-rate',
"-hs",
"--hidden-size",
type=int,
default=4,
metavar='SR',
help='frame skip rate (default: 4)')
default=512,
help="LSTM Cell number of features in the hidden state h",
)
parser.add_argument(
'--seed',
"-sk", "--skip-rate",
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--new-gym-eval',
default=False,
metavar='NGE',
help='Create a gym evaluation for upload')
default=4,
help="frame skip rate (default: 4)")
args = parser.parse_args()

setup_json = read_config(args.env_config)
@@ -84,40 +91,38 @@
if i in args.env:
env_conf = setup_json[i]

saved_state = torch.load(
f"{args.load_model_dir}{args.env}.dat", map_location=lambda storage, loc: storage
)


setup_logger(f"{args.env}_mon_log", rf"{args.log_dir}{args.env}_mon_log")
log = logging.getLogger(f"{args.env}_mon_log")

gpu_id = args.gpu_id

torch.manual_seed(args.seed)
if gpu_id >= 0:
torch.cuda.manual_seed(args.seed)

saved_state = torch.load(
'{0}{1}.dat'.format(args.load_model_dir, args.env),
map_location=lambda storage, loc: storage)

log = {}
setup_logger('{}_mon_log'.format(args.env), r'{0}{1}_mon_log'.format(
args.log_dir, args.env))
log['{}_mon_log'.format(args.env)] = logging.getLogger('{}_mon_log'.format(
args.env))

d_args = vars(args)
for k in d_args.keys():
log['{}_mon_log'.format(args.env)].info('{0}: {1}'.format(k, d_args[k]))
log.info(f"{k}: {d_args[k]}")

env = atari_env("{}".format(args.env), env_conf, args)
env = atari_env(f"{args.env}", env_conf, args)
num_tests = 0
start_time = time.time()
reward_total_sum = 0
player = Agent(None, env, args, None)
player.model = A3Clstm(player.env.observation_space.shape[0],
player.env.action_space)
player.model = A3Clstm(player.env.observation_space.shape[0], player.env.action_space, args)
player.gpu_id = gpu_id
if gpu_id >= 0:
with torch.cuda.device(gpu_id):
player.model = player.model.cuda()
if args.new_gym_eval:
player.env = gym.wrappers.Monitor(
player.env, "{}_monitor".format(args.env), force=True)
player.env, f"{args.env}_monitor", force=True)

if gpu_id >= 0:
with torch.cuda.device(gpu_id):
@@ -126,38 +131,39 @@
player.model.load_state_dict(saved_state)

player.model.eval()
for i_episode in range(args.num_episodes):
player.state = player.env.reset()
player.state = torch.from_numpy(player.state).float()
if gpu_id >= 0:
with torch.cuda.device(gpu_id):
player.state = player.state.cuda()
player.eps_len += 2
reward_sum = 0
while True:
if args.render:
if i_episode % args.render_freq == 0:
player.env.render()

player.action_test()
reward_sum += player.reward
try:
for i_episode in range(args.num_episodes):
player.state = player.env.reset()
if gpu_id >= 0:
with torch.cuda.device(gpu_id):
player.state = torch.from_numpy(player.state).float().cuda()
else:
player.state = torch.from_numpy(player.state).float()
player.eps_len = 0
reward_sum = 0
while 1:
if args.render:
if i_episode % args.render_freq == 0:
player.env.render()
player.action_test()
reward_sum += player.reward
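# EpisodicLifeEnv signals done on every lost life; only count the episode when the underlying game is really over (was_real_done)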
if player.done and not player.env.was_real_done:
state = player.env.reset()
player.state = torch.from_numpy(state).float()
if gpu_id >= 0:
with torch.cuda.device(gpu_id):
player.state = player.state.cuda()
elif player.env.was_real_done:
num_tests += 1
reward_total_sum += reward_sum
reward_mean = reward_total_sum / num_tests
log.info(
f"Time {time.strftime('%Hh %Mm %Ss', time.gmtime(time.time() - start_time))}, episode reward {reward_sum}, episode length {player.eps_len}, reward mean {reward_mean:.4f}"
)
break
except KeyboardInterrupt:
print("KeyboardInterrupt exception is caught")
finally:
print("gym evalualtion process finished")

if player.done and not player.info:
state = player.env.reset()
player.eps_len += 2
player.state = torch.from_numpy(state).float()
if gpu_id >= 0:
with torch.cuda.device(gpu_id):
player.state = player.state.cuda()
elif player.info:
num_tests += 1
reward_total_sum += reward_sum
reward_mean = reward_total_sum / num_tests
log['{}_mon_log'.format(args.env)].info(
"Time {0}, episode reward {1}, episode length {2}, reward mean {3:.4f}".
format(
time.strftime("%Hh %Mm %Ss",
time.gmtime(time.time() - start_time)),
reward_sum, player.eps_len, reward_mean))
player.eps_len = 0
break
player.env.close()

