Added minimum and maximum reward restrictions and some safeguards against sampling non-realistic game states.
tambetm committed Sep 25, 2015
1 parent d6bb455 commit 0f0800f
Showing 7 changed files with 31 additions and 7 deletions.
play.sh: 2 changes (1 addition & 1 deletion)
@@ -1,3 +1,3 @@
 #!/usr/bin/env bash
 
-python src/main.py --replay_size 5 --play_games 1 --display_screen $*
+python src/main.py --replay_size 100 --play_games 1 --display_screen $*
profile_random.sh: 4 changes (4 additions & 0 deletions)
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+# always act randomly, never predict
+python -m cProfile -s cumtime $* src/main.py --replay_size 100 --random_steps=1000 --train_steps=0 --test_steps=0 --epochs=1 roms/pong.bin
profile_test.sh: 2 changes (1 addition & 1 deletion)
@@ -1,4 +1,4 @@
 #!/usr/bin/env bash
 
 # never explore, always predict
-python -m cProfile -s cumtime $* src/main.py --replay_size 5 --exploration_rate_test=0 --random_steps=0 --train_steps=0 --test_steps=1000 --epochs=1 roms/pong.bin
+python -m cProfile -s cumtime $* src/main.py --replay_size 100 --exploration_rate_test=0 --random_steps=0 --train_steps=0 --test_steps=1000 --epochs=1 roms/pong.bin
profile_train.sh: 2 changes (1 addition & 1 deletion)
@@ -1,4 +1,4 @@
 #!/usr/bin/env bash
 
 # predict all moves by random, to separate prediction and training time
-python -m cProfile -s cumtime $* src/main.py --exploration_rate_end=1 --replay_size 5 --random_steps=5 --train_steps=1000 --test_steps=0 --epochs=1 --train_frequency 1 --target_steps 0 roms/pong.bin
+python -m cProfile -s cumtime $* src/main.py --replay_size 100 --exploration_rate_end=1 --random_steps=5 --train_steps=1000 --test_steps=0 --epochs=1 --train_frequency 1 --target_steps 0 roms/pong.bin
src/main.py: 2 changes (2 additions & 0 deletions)
@@ -30,6 +30,8 @@
 memarg = parser.add_argument_group('Replay memory')
 memarg.add_argument("--replay_size", type=int, default=1000000, help="Maximum size of replay memory.")
 memarg.add_argument("--history_length", type=int, default=4, help="How many screen frames form a state.")
+memarg.add_argument("--min_reward", type=float, default=-1, help="Minimum reward.")
+memarg.add_argument("--max_reward", type=float, default=1, help="Maximum reward.")
 
 netarg = parser.add_argument_group('Deep Q-learning network')
 netarg.add_argument("--learning_rate", type=float, default=0.00025, help="Learning rate.")
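The defaults keep rewards in [-1, 1], the usual DQN reward clipping range. A minimal sketch of how the new flags parse on their own, using only the two argparse calls added above and nothing beyond the standard library:

import argparse

parser = argparse.ArgumentParser()
memarg = parser.add_argument_group('Replay memory')
memarg.add_argument("--min_reward", type=float, default=-1, help="Minimum reward.")
memarg.add_argument("--max_reward", type=float, default=1, help="Maximum reward.")

# With no arguments the defaults give the [-1, 1] clipping range;
# a value of 0 would disable that bound, because the replay memory
# only clips when the bound is truthy (see src/replay_memory.py below).
args = parser.parse_args([])
print(args.min_reward, args.max_reward)  # -1.0 1.0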
src/replay_memory.py: 24 changes (21 additions & 3 deletions)
@@ -14,6 +14,8 @@ def __init__(self, args):
     self.history_length = args.history_length
     self.dims = (args.screen_height, args.screen_width)
     self.batch_size = args.batch_size
+    self.min_reward = args.min_reward
+    self.max_reward = args.max_reward
     self.count = 0
     self.current = 0
 
@@ -27,6 +29,15 @@ def add(self, action, reward, screen, terminal):
     assert screen.shape == self.dims
     # NB! screen is post-state, after action and reward
     self.actions[self.current] = action
+    # clip reward between -1 and 1
+    if self.min_reward and reward < self.min_reward:
+      #logger.debug("Smaller than min_reward: %d" % reward)
+      reward = max(reward, self.min_reward)
+      #logger.info("After clipping: %d" % reward)
+    if self.max_reward and reward > self.max_reward:
+      #logger.debug("Bigger than max_reward: %d" % reward)
+      reward = min(reward, self.max_reward)
+      #logger.info("After clipping: %d" % reward)
     self.rewards[self.current] = reward
     self.screens[self.current, ...] = screen
     self.terminals[self.current] = terminal
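Taken together, the two new branches clamp the reward into [min_reward, max_reward], and a bound set to 0 (falsy) switches that side of the clipping off. A minimal standalone sketch of the same behaviour; the clip_reward helper name is illustrative and not part of the repository:

def clip_reward(reward, min_reward=-1, max_reward=1):
    # Same effect as the branches above: clamp into [min_reward, max_reward].
    # A falsy bound (0 or None) leaves that side unclipped, mirroring the
    # truthiness checks on self.min_reward / self.max_reward.
    if min_reward:
        reward = max(reward, min_reward)
    if max_reward:
        reward = min(reward, max_reward)
    return reward

assert clip_reward(100) == 1     # big game score clipped to +1
assert clip_reward(-7) == -1     # big negative score clipped to -1
assert clip_reward(0.5) == 0.5   # rewards already inside the range pass through
assert clip_reward(-7, min_reward=0) == -7   # bound of 0 disables clipping on that side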
@@ -61,10 +72,17 @@ def getMinibatch(self):
     while len(indexes) < self.batch_size:
       # find random index
       while True:
+        # sample one index (ignore states wrapping over the current pointer or an episode end)
         index = random.randint(self.history_length, self.count - 1)
-        # if does not wrap over episode end
-        if not self.terminals[(index - self.history_length):index].any():
-          break
+        # if wraps over current pointer, then get new one
+        if index >= self.current and index - self.history_length < self.current:
+          continue
+        # if wraps over episode end, then get new one
+        # NB! poststate (last screen) can be terminal state!
+        if self.terminals[(index - self.history_length):index].any():
+          continue
+        # otherwise use this index
+        break
 
       # NB! having index first is fastest in C-order matrices
       self.prestates[len(indexes), ...] = self.getState(index - 1)
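The rewritten inner loop rejects two kinds of indices before they reach the batch: a history window that straddles self.current mixes screens from different episodes (that part of the buffer is being overwritten), and a window containing a terminal flag would span an episode boundary; the slice stops at index, so the poststate itself may still be terminal, as the committed comment notes. A sketch of the same test as a standalone predicate; the function and argument names are illustrative, not from the repository:

import numpy as np

def is_valid_index(index, current, history_length, terminals):
    # Reject windows that straddle the write pointer: frames on either side
    # of `current` belong to different episodes once the buffer wraps.
    if index >= current and index - history_length < current:
        return False
    # Reject windows that cross an episode end; the slice excludes `index`
    # itself, so the poststate (last screen) may be terminal.
    if terminals[(index - history_length):index].any():
        return False
    return True

terminals = np.zeros(10, dtype=bool)
terminals[4] = True
assert not is_valid_index(8, current=7, history_length=4, terminals=terminals)  # straddles the write pointer
assert not is_valid_index(6, current=0, history_length=4, terminals=terminals)  # crosses the terminal at 4
assert is_valid_index(9, current=0, history_length=4, terminals=terminals)      # clean window, accepted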
test.sh: 2 changes (1 addition & 1 deletion)
@@ -1,3 +1,3 @@
 #!/usr/bin/env bash
 
-python src/main.py --replay_size 5 --random_steps 0 --train_steps 0 --epochs 1 $*
+python src/main.py --replay_size 100 --random_steps 0 --train_steps 0 --epochs 1 $*
