Add testing files

aleju committed May 24, 2016
1 parent 1b8b475 commit 7069ae7
Showing 4 changed files with 225 additions and 12 deletions.
30 changes: 30 additions & 0 deletions actions.lua
@@ -223,4 +223,34 @@ function actions.setJoypad2(actions)
--end
end

function actions.chooseAction(lastStates, perfect, bestAction, pExplore)
    perfect = perfect or false
    pExplore = pExplore or STATS.P_EXPLORE_CURRENT
    local _action, _actionValue
    if not perfect and math.random() < pExplore then
        if bestAction == nil or math.random() < 0.5 then
            -- randomize both arrow and button
            _action = Action.new(util.getRandomEntry(actions.ACTIONS_ARROWS), util.getRandomEntry(actions.ACTIONS_BUTTONS))
        else
            -- randomize only the arrow or only the button
            if math.random() < 0.5 then
                _action = Action.new(util.getRandomEntry(actions.ACTIONS_ARROWS), bestAction.button)
            else
                _action = Action.new(bestAction.arrow, util.getRandomEntry(actions.ACTIONS_BUTTONS))
            end
        end
        --print("Choosing action randomly:", _action)
    else
        if bestAction ~= nil then
            _action = bestAction
        else
            -- Use the network to approximate the action with maximal value
            _action, _actionValue = network.approximateBestAction(lastStates)
            --print("Q approximated action:", _action, actions.ACTION_TO_BUTTON_NAME[_action])
        end
    end

    return _action
end

return actions
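
A minimal sketch of how actions.chooseAction is meant to be called from a per-frame handler, mirroring the usage in test.lua further down in this commit; passing nil as bestAction and an explicit epsilon here is only illustrative:

    -- Sketch: select and apply an action for the most recent state chain.
    -- STATES_PER_EXAMPLE and STATS.P_EXPLORE_CURRENT come from config.lua.
    local stateChain = states.getLastEntries(STATES_PER_EXAMPLE)
    local action = actions.chooseAction(stateChain, false, nil, STATS.P_EXPLORE_CURRENT)
    actions.endAllActions()
    actions.startAction(action)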
98 changes: 98 additions & 0 deletions config.lua
@@ -0,0 +1,98 @@
-- Keep the emulator's memory library accessible;
-- we will use the name "memory" for the replay memory instead
lsne_memory = memory

----------------------------------
-- requires
----------------------------------
require 'torch'
require 'image'
require 'nn'
require 'optim'
memory = require 'memory_lsqlite3'
network = require 'network'
actions = require 'actions'
Action = require 'action'
util = require 'util'
states = require 'states_sqlite'
State = require 'state'
rewards = require 'rewards'
Reward = require 'reward'
ForgivingMSECriterion = require 'layers.ForgivingMSECriterion'
ForgivingAbsCriterion = require 'layers.ForgivingAbsCriterion'
ok, display = pcall(require, 'display')
if not ok then print('display not found. unable to plot') end

----------------------------------
-- RNG seed
----------------------------------
SEED = 43

----------------------------------
-- GPU / cudnn
----------------------------------
GPU = 0
require 'cutorch'
require 'cunn'
require 'cudnn'
if GPU >= 0 then
    print(string.format("Using gpu device %d", GPU))
    cutorch.setDevice(GPU + 1)
    cutorch.manualSeed(SEED)

    -- Saves 40% time according to http://torch.ch/blog/2016/02/04/resnets.html
    cudnn.fastest = true
    cudnn.benchmark = true
end
math.randomseed(SEED)
torch.manualSeed(SEED)
torch.setdefaulttensortype('torch.FloatTensor')
--------------------------------

----------------------------------
-- Other settings
----------------------------------
FPS = movie.get_game_info().fps
REACT_EVERY_NTH_FRAME = 5
print(string.format("FPS: %d, Reacting every %d frames", FPS, REACT_EVERY_NTH_FRAME))

-- filepath where current game's last screenshot will be saved
-- ideally on a ramdisk (for speed and less stress on the hard drive)
SCREENSHOT_FILEPATH = "/media/ramdisk/mario-ai-screenshots/current-screen.png"

IMG_DIMENSIONS = {1, 64, 64} -- screenshots will be resized to this immediately
IMG_DIMENSIONS_Q_HISTORY = {1, 32, 32} -- size of images fed into Q (action history)
IMG_DIMENSIONS_Q_LAST = {1, 64, 64} -- size of the last state's image fed into Q
--IMG_DIMENSIONS_AE = {1, 128, 128}

BATCH_SIZE = 16
STATES_PER_EXAMPLE = 4 -- how many states (previous + last one) to use per example fed into Q

GAMMA_EXPECTED = 0.9 -- discount factor to use for future rewards anticipated by Q
GAMMA_OBSERVED = 0.9 -- discount factor to use when cascading observed direct rewards backwards through time
MAX_GAMMA_REWARD = 100 -- clamp future rewards to +/- this value

P_EXPLORE_START = 0.8 -- starting epsilon value for the epsilon-greedy policy
P_EXPLORE_END = 0.1 -- final epsilon value for the epsilon-greedy policy
P_EXPLORE_END_AT = 400000 -- number of chosen actions after which P_EXPLORE_END is reached

LAST_SAVE_STATE_LOAD = 0 -- last time (in number of actions) when the game has been reset to a saved state

Q_L2_NORM = 1e-6 -- L2 parameter norm for Q
Q_CLAMP = 5 -- clamp Q gradients to +/- this value

----------------------------------
-- stats per training, will be saved and reloaded when training continues
----------------------------------
STATS = {
    STATE_ID = 0, -- id of the last created state
    FRAME_COUNTER = 0, -- number of the last frame
    ACTION_COUNTER = 0, -- count of actions chosen so far
    CURRENT_DIRECT_REWARD_SUM = 0, -- no longer used?
    CURRENT_OBSERVED_GAMMA_REWARD_SUM = 0, -- no longer used?
    AVERAGE_REWARD_DATA = {}, -- plot datapoints of rewards per N states
    AVERAGE_LOSS_DATA = {}, -- plot datapoints of losses per N batches
    LAST_BEST_ACTION_VALUE = 0, -- no longer used?
    P_EXPLORE_CURRENT = P_EXPLORE_START -- current epsilon value for the epsilon-greedy policy
}
STATS.STATE_ID = memory.getMaxStateId(1)
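
The P_EXPLORE_* constants describe an epsilon that falls from P_EXPLORE_START to P_EXPLORE_END over the first P_EXPLORE_END_AT chosen actions. The annealing itself is not part of this commit; a linear-decay sketch consistent with these constants could look like this (annealEpsilon is a hypothetical helper name):

    -- Sketch (not in this commit): linear epsilon decay over the first
    -- P_EXPLORE_END_AT actions, then held constant at P_EXPLORE_END.
    local function annealEpsilon(actionCounter)
        local progress = math.min(actionCounter / P_EXPLORE_END_AT, 1.0)
        return P_EXPLORE_START + progress * (P_EXPLORE_END - P_EXPLORE_START)
    end
    -- e.g. STATS.P_EXPLORE_CURRENT = annealEpsilon(STATS.ACTION_COUNTER)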
77 changes: 77 additions & 0 deletions test.lua
@@ -0,0 +1,77 @@
print("------------------------")
print("TESTING")
print("------------------------")

paths.dofile('config.lua')

print("Loading network...")
Q = network.load()
assert(Q ~= nil)

-- count parameters
--print("Number of parameters in AE:", network.getNumberOfParameters(MODEL_AE))
print("Number of parameters in Q:", network.getNumberOfParameters(Q))

print("Loading memory...")
memory.load()

print("Loading stats...")
util.loadStats()
STATS.P_EXPLORE_CURRENT = 0.0

print("Starting loop.")

function on_frame_emulated()
    local lastLastState = states.getEntry(-2)
    local lastState = states.getEntry(-1)
    STATS.FRAME_COUNTER = movie.currentframe()

    if STATS.FRAME_COUNTER % REACT_EVERY_NTH_FRAME ~= 0 then
        return
    end

    STATS.ACTION_COUNTER = STATS.ACTION_COUNTER + 1

    local state = State.new(nil, util.getScreenCompressed(), util.getCurrentScore(), util.getCountLifes(), util.getLevelBeatenStatus(), util.getMarioGameStatus(), util.getPlayerX(), util.getMarioImage(), util.isLevelEnding())
    states.addEntry(state) -- getLastEntries() depends on this, don't move it after the next code block
    --print("Score:", score, "Level:", util.getLevel(), "x:", playerX, "status:", marioGameStatus, "levelBeatenStatus:", levelBeatenStatus, "count lifes:", countLifes, "Mario Image", util.getMarioImage())

    -- Calculate reward
    local rew, bestAction, bestActionValue = rewards.statesToReward(states.getLastEntries(STATES_PER_EXAMPLE))
    lastState.reward = rew
    --print(string.format("[Reward] R=%.2f DR=%.2f SDR=%.2f XDR=%.2f LBR=%.2f EGR=%.2f", rewards.getSumExpected(lastState.reward), rewards.getDirectReward(lastState.reward), lastState.reward.scoreDiffReward, lastState.reward.xDiffReward, lastState.reward.levelBeatenReward, lastState.reward.expectedGammaReward))
    states.cascadeBackReward(lastState.reward)
    STATS.LAST_BEST_ACTION_VALUE = bestActionValue

    -- show the state chain
    -- must happen before training as it might depend on the network's current output
    display.image(states.stateChainsToImage({states.getLastEntries(STATES_PER_EXAMPLE)}, Q), {win=17, title="Last states"})

    -- plot average rewards
    if STATS.ACTION_COUNTER % 1 == 0 then
        states.plotRewards()
    end

    --------------------

    state.action = actions.chooseAction(states.getLastEntries(STATES_PER_EXAMPLE), false, bestAction)

    local levelEnded = state.levelBeatenStatus == 128 or state.marioGameStatus == 2
    if levelEnded then
        print("Reloading saved gamestate...")
        states.clear()

        -- Reload the save state if the level was beaten or Mario died
        util.loadRandomTrainingSaveState()
        LAST_SAVE_STATE_LOAD = STATS.ACTION_COUNTER
    else
        actions.endAllActions()
        actions.startAction(state.action)
    end
end

actions.endAllActions()
util.loadRandomTrainingSaveState()
util.setGameSpeedToVeryFast()
states.fillWithEmptyStates()
gui.repaint()
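
states.cascadeBackReward, used in the handler above, is defined elsewhere in the repository. Judging from GAMMA_OBSERVED and MAX_GAMMA_REWARD in config.lua, it presumably propagates the newly observed reward backwards through the recent states with exponential discounting; the following is only a rough sketch of that idea, and the field name observedGammaReward is a placeholder, not necessarily the real one:

    -- Rough sketch: discount an observed reward backwards through earlier states,
    -- clamping the propagated value to +/- MAX_GAMMA_REWARD.
    local function cascadeBackRewardSketch(stateChain, newReward)
        local r = rewards.getDirectReward(newReward)
        for i = #stateChain - 1, 1, -1 do
            r = r * GAMMA_OBSERVED
            local clamped = math.max(-MAX_GAMMA_REWARD, math.min(r, MAX_GAMMA_REWARD))
            -- placeholder field; the real reward object is defined in reward.lua
            stateChain[i].reward.observedGammaReward = clamped
        end
    end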
32 changes: 20 additions & 12 deletions util.lua
@@ -180,20 +180,28 @@ function util.toImageDimensions(img, dimensions)
return img
end

-- Take a screenshot of the game and return it as a tensor.
-- TODO no longer used?
function util.getScreen()
    local fp = SCREENSHOT_FILEPATH
    gui.screenshot(fp)
    local screen = image.load(fp, 3, "float"):clone()
    screen = image.scale(screen, IMG_DIMENSIONS[2], IMG_DIMENSIONS[3]):clone()
    if IMG_DIMENSIONS[1] == 1 then
        screen = util.rgb2y(screen)
    end
    return screen
end

-- Take a screenshot of the game and return it jpg-compressed as a tensor.
function util.getScreenCompressed()
    local fp = SCREENSHOT_FILEPATH
    gui.screenshot(fp)
    return util.loadJPGCompressed(fp, IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3])
end

function util.loadJPGCompressed(fp, channels, height, width)
    -- from https://github.com/torch/image/blob/master/doc/saveload.md
    --[[
    local fin = torch.DiskFile(fp, 'r')
    fin:binary()
    fin:seekEnd()
    local file_size_bytes = fin:position() - 1
    fin:seek(1)
    local img_binary = torch.ByteTensor(file_size_bytes)
    fin:readByte(img_binary:storage())
    fin:close()
    -- Then when you're ready to decompress the ByteTensor:
    im = image.decompressJPG(img_binary, 3)
    --]]
    local im = image.load(fp, 3, "float")
    local c, h, w = im:size(1), im:size(2), im:size(3)
    im = im[{{1,c}, {30,h}, {1,w}}] -- cut off 30px from the top
