forked from aleju/mario-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
225 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
-- The emulator exposes its own "memory" library as a global; stash it under
-- another name because the global "memory" is reassigned below to the
-- replay-memory module.
lsne_memory = memory

----------------------------------
-- requires
----------------------------------
-- Note: these are deliberately global (emulator-script convention) so that
-- other scripts loaded via dofile() can reach them.
require 'torch'
require 'image'
require 'nn'
require 'optim'
memory = require 'memory_lsqlite3'
network = require 'network'
actions = require 'actions'
Action = require 'action'
util = require 'util'
states = require 'states_sqlite'
State = require 'state'
rewards = require 'rewards'
Reward = require 'reward'
ForgivingMSECriterion = require 'layers.ForgivingMSECriterion'
ForgivingAbsCriterion = require 'layers.ForgivingAbsCriterion'
-- 'display' is optional; fall back to no plotting when it is missing.
ok, display = pcall(require, 'display')
if not ok then print('display not found. unable to plot') end
|
||
----------------------------------
-- RNG seed
----------------------------------
-- Single seed shared by all RNGs (lua, torch, cutorch) for reproducible runs.
SEED = 43

----------------------------------
-- GPU / cudnn
----------------------------------
-- CUDA device index (0-based). Set to a negative value to run on CPU only.
GPU = 0
if GPU >= 0 then
    -- Only pull in the CUDA stack when a GPU is actually requested;
    -- requiring cutorch/cunn/cudnn unconditionally would crash on
    -- machines without CUDA even though the guard below suggests
    -- CPU-only operation is supported.
    require 'cutorch'
    require 'cunn'
    require 'cudnn'
    print(string.format("Using gpu device %d", GPU))
    cutorch.setDevice(GPU + 1) -- torch device ids are 1-based
    cutorch.manualSeed(SEED)

    -- Saves 40% time according to http://torch.ch/blog/2016/02/04/resnets.html
    cudnn.fastest = true
    cudnn.benchmark = true
end
math.randomseed(SEED)
torch.manualSeed(SEED)
torch.setdefaulttensortype('torch.FloatTensor')
||
----------------------------------
-- Other settings
----------------------------------
-- Frames per second as reported by the emulator's movie API.
FPS = movie.get_game_info().fps
-- The agent only observes/acts on every Nth emulated frame.
REACT_EVERY_NTH_FRAME = 5
print(string.format("FPS: %d, Reacting every %d frames", FPS, REACT_EVERY_NTH_FRAME))

-- filepath where current game's last screenshot will be saved
-- ideally on a ramdisk (for speed and less stress on the hard drive)
SCREENSHOT_FILEPATH = "/media/ramdisk/mario-ai-screenshots/current-screen.png"

IMG_DIMENSIONS = {1, 64, 64}           -- screenshots will be resized to this immediately
IMG_DIMENSIONS_Q_HISTORY = {1, 32, 32} -- size of images fed into Q (action history)
IMG_DIMENSIONS_Q_LAST = {1, 64, 64}    -- size of the last state's image fed into Q
--IMG_DIMENSIONS_AE = {1, 128, 128}

BATCH_SIZE = 16
STATES_PER_EXAMPLE = 4 -- how many states (previous + last one) to use per example fed into Q

GAMMA_EXPECTED = 0.9   -- discount factor to use for future rewards anticipated by Q
GAMMA_OBSERVED = 0.9   -- discount factor to use when cascading observed direct rewards backwards through time
MAX_GAMMA_REWARD = 100 -- clamp future rewards to +/- this value

P_EXPLORE_START = 0.8    -- starting epsilon value for epsilon greedy policy
P_EXPLORE_END = 0.1      -- ending epsilon value for epsilon greedy policy
P_EXPLORE_END_AT = 400000 -- when to end at P_EXPLORE_END (number of chosen actions)

LAST_SAVE_STATE_LOAD = 0 -- last time (in number of actions) when the game has been reset to a saved state

Q_L2_NORM = 1e-6 -- L2 parameter norm for Q
Q_CLAMP = 5      -- clamp Q gradients to +/- this value
|
||
----------------------------------
-- stats per training, will be saved and reloaded when training continues
----------------------------------
STATS = {
    STATE_ID = 0,                          -- id of the last created state
    FRAME_COUNTER = 0,                     -- number of the last frame
    ACTION_COUNTER = 0,                    -- count of actions chosen so far
    CURRENT_DIRECT_REWARD_SUM = 0,         -- no longer used?
    CURRENT_OBSERVED_GAMMA_REWARD_SUM = 0, -- no longer used?
    AVERAGE_REWARD_DATA = {},              -- plot datapoints of rewards per N states
    AVERAGE_LOSS_DATA = {},                -- plot datapoints of losses per N batches
    LAST_BEST_ACTION_VALUE = 0,            -- no longer used?
    P_EXPLORE_CURRENT = P_EXPLORE_START    -- current epsilon value for epsilon greedy policy
}
-- Resume the state id counter from whatever is already stored in the replay db.
STATS.STATE_ID = memory.getMaxStateId(1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
print("------------------------")
print("TESTING")
print("------------------------")

-- Pull in globals, hyperparameters and the replay-memory/Q modules.
paths.dofile('config.lua')

print("Loading network...")
Q = network.load()
assert(Q ~= nil)

-- count parameters
--print("Number of parameters in AE:", network.getNumberOfParameters(MODEL_AE))
print("Number of parameters in Q:", network.getNumberOfParameters(Q))

print("Loading memory...")
memory.load()

print("Loading stats...")
util.loadStats()
-- Pure exploitation while testing: epsilon-greedy exploration disabled.
STATS.P_EXPLORE_CURRENT = 0.0

print("Starting loop.")
||
--- Callback invoked by the emulator after every emulated frame.
-- Records the current game situation as a state, assigns the observed reward
-- to the previous state, picks the next action (greedily, since epsilon is 0
-- during testing) and reloads a random training savestate whenever the level
-- ends or mario dies.
function on_frame_emulated()
    local lastState = states.getEntry(-1)
    STATS.FRAME_COUNTER = movie.currentframe()

    -- The agent only reacts every Nth frame; other frames are skipped.
    if STATS.FRAME_COUNTER % REACT_EVERY_NTH_FRAME ~= 0 then
        return
    end

    STATS.ACTION_COUNTER = STATS.ACTION_COUNTER + 1

    -- Snapshot the current situation (screen, score, lives, level/game
    -- status, player position) into a new state.
    local state = State.new(nil, util.getScreenCompressed(), util.getCurrentScore(), util.getCountLifes(), util.getLevelBeatenStatus(), util.getMarioGameStatus(), util.getPlayerX(), util.getMarioImage(), util.isLevelEnding())
    states.addEntry(state) -- getLastEntries() depends on this, don't move it after the next code block

    -- Calculate reward for the previous state and cascade it backwards
    -- through time.
    local rew, bestAction, bestActionValue = rewards.statesToReward(states.getLastEntries(STATES_PER_EXAMPLE))
    lastState.reward = rew
    states.cascadeBackReward(lastState.reward)
    STATS.LAST_BEST_ACTION_VALUE = bestActionValue

    -- show state chain
    -- must happen before training as it might depend on network's current output
    -- NOTE(review): config.lua loads 'display' via pcall, so it may be nil;
    -- guard the call so testing also works without the display package.
    if display then
        display.image(states.stateChainsToImage({states.getLastEntries(STATES_PER_EXAMPLE)}, Q), {win=17, title="Last states"})
    end

    -- plot average rewards ("% 1" plots on every action; kept as-is)
    if STATS.ACTION_COUNTER % 1 == 0 then
        states.plotRewards()
    end

    --------------------

    state.action = actions.chooseAction(states.getLastEntries(STATES_PER_EXAMPLE), false, bestAction)

    -- presumably levelBeatenStatus == 128 marks a beaten level and
    -- marioGameStatus == 2 marks death -- TODO confirm against util.lua
    local levelEnded = state.levelBeatenStatus == 128 or state.marioGameStatus == 2
    if levelEnded then
        print("Reloading saved gamestate...")
        states.clear()

        -- Reload save state if level was beaten or mario died
        util.loadRandomTrainingSaveState()
        LAST_SAVE_STATE_LOAD = STATS.ACTION_COUNTER
    else
        actions.endAllActions()
        actions.startAction(state.action)
    end
end
|
||
actions.endAllActions() | ||
util.loadRandomTrainingSaveState() | ||
util.setGameSpeedToVeryFast() | ||
states.fillWithEmptyStates() | ||
gui.repaint() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters