diff --git a/memory_sqlite.lua b/memory_sqlite.lua
new file mode 100644
index 0000000..710e71d
--- /dev/null
+++ b/memory_sqlite.lua
@@ -0,0 +1,418 @@
+require 'torch'
+require 'paths'
+local driver = require 'luasql.sqlite3'
+
+local env = assert(driver.sqlite3())
+-- connect to data source
+local con = assert(env:connect("learned/memory.sqlite"))
+
+assert(con:execute[[
+    CREATE TABLE IF NOT EXISTS states(
+        id INTEGER(12) PRIMARY KEY,
+        screen_jpg BLOB,
+        score INTEGER(12),
+        count_lifes INTEGER(6),
+        level_beaten_status INTEGER(6),
+        mario_game_status INTEGER(6),
+        player_x INTEGER(6),
+        mario_image INTEGER(6),
+        is_level_ending INTEGER(1),
+        action_arrow INTEGER(2),
+        action_button INTEGER(2),
+        reward_score_diff REAL,
+        reward_x_diff REAL,
+        reward_level_beaten REAL,
+        reward_expected_gamma REAL,
+        reward_expected_gamma_raw REAL,
+        reward_observed_gamma REAL,
+        is_dummy INTEGER(1),
+        is_validation INTEGER(1)
+    )
+]])
+
+--res = assert(con:execute[[
+--    CREATE TABLE memories(
+--        state_chain_id INT(12),
+--        state_chain_pos INT(4),
+--        state_id INT(12)
+--    )
+--]])
+
+local memory = {}
+memory.MEMORY_MAX_SIZE_TRAINING = 120000
+memory.MEMORY_MAX_SIZE_VALIDATION = 10000
+memory.MEMORY_TRAINING_MIN_SIZE = 100
+
+function memory.load()
+end
+
+function memory.save()
+end
+
+function memory.getCountEntries(validation)
+    assert(validation == true or validation == false)
+    local val_int = 0
+    if validation then val_int = 1 end
+    local cur = assert(con:execute(string.format("SELECT COUNT(*) AS c FROM states WHERE is_validation = %d", val_int)))
+    local row = cur:fetch({}, "a")
+    -- luasql returns column values as strings, so convert the count before returning it
+    return tonumber(row.c)
+end
+
+function memory.getCountAllEntries()
+    local cur = assert(con:execute("SELECT COUNT(*) AS c FROM states"))
+    local row = cur:fetch({}, "a")
+    return tonumber(row.c)
+end
+
+function memory.getMaxStateId(defaultVal)
+    if memory.getCountAllEntries() == 0 then
+        return defaultVal
+    else
+        local cur = assert(con:execute("SELECT MAX(id) AS id FROM states"))
+        local row = cur:fetch({}, "a")
+        return tonumber(row.id)
+    end
+end
+
+function memory.isTrainDataFull()
+    return memory.getCountEntries(false) >= memory.MEMORY_MAX_SIZE_TRAINING
+end
+
+function memory.reevaluateRewards()
+end
+
+function memory.reorderByDirectReward()
+end
+
+function memory.reorderBySurprise()
+end
+
+function memory.reorderBy(func)
+end
+
+function memory.addEntry(stateChain, nextState, validation)
+    --[[
+    assert(validation == true or validation == false)
+    for i=1,#stateChain do
+        memory.insertState(stateChain[i], validation, false)
+    end
+    memory.insertState(nextState, validation, false)
+
+    if math.random() < 1/1000 then
+        memory.reduceToMaxSizes()
+    end
+    --]]
+end
+
+function memory.addState(state, validation)
+    assert(validation == true or validation == false)
+    memory.insertState(state, validation, false)
+    if math.random() < 1/1000 then
+        print("Reducing to max size...")
+        memory.reduceToMaxSizes()
+    end
+end
+
+function memory.insertState(state, validation, updateIfExists)
+    assert(state.action ~= nil)
+    assert(state.reward ~= nil)
+    --print("X")
+    -- SQLite has no "INSERT OR UPDATE"; REPLACE is the conflict clause that overwrites an existing row
+    local ifExistsCommand = "IGNORE"
+    if updateIfExists == true then ifExistsCommand = "REPLACE" end
+    local screen_jpg_serialized = torch.serialize(state.screen, "ascii")
+    local isLevelEnding_int = 0
+    if state.isLevelEnding then isLevelEnding_int = 1 end
+    local dummy_int = 0
+    if state.isDummy then dummy_int = 1 end
+    local val_int = 0
+    if validation then val_int = 1 end
+
+    --print("A")
+    local query = string.format(
[[ + INSERT OR %s INTO states ( + id, + screen_jpg, + score, + count_lifes, + level_beaten_status, + mario_game_status, + player_x, + mario_image, + is_level_ending, + action_arrow, + action_button, + reward_score_diff, + reward_x_diff, + reward_level_beaten, + reward_expected_gamma, + reward_expected_gamma_raw, + reward_observed_gamma, + is_dummy, + is_validation + ) + VALUES + ( + %d, + '%s', + %d, + %d, + %d, + %d, + %d, + %d, + %d, + %d, + %d, + %.8f, + %.8f, + %.8f, + %.8f, + %.8f, + %.8f, + %d, + %d + ) + ]], + ifExistsCommand, + state.id, + con:escape(screen_jpg_serialized), + state.score, + state.countLifes, + state.levelBeatenStatus, + state.marioGameStatus, + state.playerX, + state.marioImage, + isLevelEnding_int, + state.action.arrow, + state.action.button, + state.reward.scoreDiffReward, + state.reward.xDiffReward, + state.reward.levelBeatenReward, + state.reward.expectedGammaReward, + state.reward.expectedGammaRewardRaw, + state.reward.observedGammaReward, + dummy_int, + val_int + ) + --print("B") + + --local query2 = con:prepare(string.format( + --[[ + INSERT OR %s states ( + id, + screen_jpg, + score, + count_lifes, + level_beaten_status, + mario_game_status, + player_x, + mario_image, + is_level_ending, + action_arrow, + action_button, + reward_score_diff, + reward_x_diff, + reward_level_beaten, + reward_expected_gamma, + reward_expected_gamma_raw, + reward_observed_gamma, + is_dummy, + is_validation + ) + VALUES + ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ? + ) + --]]--, ifExistsCommand) + --) + + --[[ + query2:bind({"INTEGER", state.id}) + query2:bind({"BLOB", screen_jpg_serialized}) + query2:bind({"INTEGER", state.score}) + query2:bind({"INTEGER", state.countLifes}) + query2:bind({"INTEGER", state.levelBeatenStatus}) + query2:bind({"INTEGER", state.marioGameStatus}) + query2:bind({"INTEGER", state.playerX}) + query2:bind({"INTEGER", state.marioImage}) + query2:bind({"INTEGER", isLevelEnding_int}) + query2:bind({"INTEGER", state.action.arrow}) + query2:bind({"INTEGER", state.action.button}) + query2:bind({"REAL", state.reward.scoreDiffReward}) + query2:bind({"REAL", state.reward.xDiffReward}) + query2:bind({"REAL", state.reward.levelBeatenReward}) + query2:bind({"REAL", state.reward.expectedGammaReward}) + query2:bind({"REAL", state.reward.expectedGammaRewardRaw}) + query2:bind({"REAL", state.reward.observedGammaReward}) + query2:bind({"INTEGER", dummy_int}) + query2:bind({"INTEGER", val_int}) + --]] + + --print(query) + --if STATS.ACTION_COUNTER < 10 then + assert(con:execute(query)) + --end + --print("C") +end + +function memory.reduceToMaxSizes() + local function deleteSubf(count, maxCount, validation) + local val_int = 0 + if validation then val_int = 1 end + local toDelete = count - maxCount + print(string.format("count=%d, max count=%d, toDelete=%d, val_int=%d", count, maxCount, toDelete, val_int)) + if toDelete > 0 then + local cur = assert(con:execute(string.format("SELECT id FROM states WHERE is_validation = %d ORDER BY id ASC LIMIT %d", val_int, toDelete))) + local row = cur:fetch({}, "a") + while row do + assert(con:execute(string.format("DELETE FROM states WHERE id = %d", row.id))) + row = cur:fetch(row, "a") + end + end + end + + deleteSubf(memory.getCountEntries(false), memory.MEMORY_MAX_SIZE_TRAINING, false) + deleteSubf(memory.getCountEntries(true), memory.MEMORY_MAX_SIZE_VALIDATION, true) +end + +function memory.removeRandomEntry(validation) +end + +function memory.removeRandomEntries(nRemove, 
validation, skew)
+end
+
+function memory.getRandomWeightedIndex(skew, validation, skewStrength)
+
+end
+
+function memory.getRandomStateChain(length, validation)
+    assert(memory.getCountEntries(validation) >= length)
+    local val_int = 0
+    if validation then val_int = 1 end
+    local cur = assert(con:execute(string.format("SELECT id FROM states WHERE is_validation = %d ORDER BY random() LIMIT 1", val_int)))
+    local row = cur:fetch({}, "a")
+    local id = tonumber(row.id)
+    local indices = {}
+    for i=1,length do table.insert(indices, id + i - 1) end
+    local states = memory.getStatesByIndices(indices)
+    if #states < length then
+        -- the chain ran into a gap (e.g. deleted ids), retry with a new random start id
+        return memory.getRandomStateChain(length, validation)
+    else
+        return states
+    end
+end
+
+function memory.getStatesByIndices(indices)
+    local indicesStr = ""
+    local indexToState = {}
+    for i=1,#indices do
+        if i == 1 then
+            indicesStr = "" .. indices[i]
+        else
+            indicesStr = indicesStr .. ", " .. indices[i]
+        end
+        indexToState[indices[i]] = false
+    end
+    local cur = assert(con:execute(string.format("SELECT * FROM states WHERE id IN (%s)", indicesStr)))
+    local row = cur:fetch({}, "a")
+    while row do
+        local state = memory.rowToState(row)
+        assert(indexToState[state.id] ~= nil)
+        indexToState[state.id] = state
+        row = cur:fetch(row, "a")
+    end
+    -- return the states in the order of the requested indices, skipping ids that were not found
+    local states = {}
+    for i=1,#indices do
+        local state = indexToState[indices[i]]
+        if state ~= false then
+            table.insert(states, state)
+        end
+    end
+    return states
+end
+
+function memory.getTrainingBatch(batchSize)
+    return memory.getBatch(batchSize, false)
+end
+
+function memory.getValidationBatch(batchSize)
+    return memory.getBatch(batchSize, true)
+end
+
+function memory.getBatch(batchSize, validation, reevaluate)
+    assert(validation == true or validation == false)
+    assert(reevaluate == true or reevaluate == false)
+    local stateChains = {}
+    local length = STATES_PER_EXAMPLE
+    if reevaluate then length = length + 1 end
+    for i=1,batchSize do
+        --local idx = memory.getRandomWeightedIndex("top", validation)
+        local stateChain = memory.getRandomStateChain(length, validation)
+        table.insert(stateChains, stateChain)
+    end
+
+    if not reevaluate then
+        local batchInput, batchTarget = network.stateChainsToBatch(stateChains)
+        return batchInput, batchTarget
+    else
+        -- split each chain of length STATES_PER_EXAMPLE+1 into a "current" chain (states 1..N)
+        -- and a "next" chain (states 2..N+1), then re-estimate the reward of the last current state
+        local stateChainsCurrent = {}
+        local stateChainsNext = {}
+        for i=1,#stateChains do
+            local stateChain = stateChains[i]
+            local stateChainCurrent = {}
+            local stateChainNext = {}
+            for j=1,length-1 do
+                table.insert(stateChainCurrent, stateChain[j])
+                table.insert(stateChainNext, stateChain[j+1])
+            end
+            table.insert(stateChainsCurrent, stateChainCurrent)
+            table.insert(stateChainsNext, stateChainNext)
+        end
+        local bestActions = network.approximateBestActionsBatch(stateChainsNext)
+        for i=1,#stateChains do
+            local oldReward = stateChainsCurrent[i][length-1].reward
+            local newReward = rewards.statesToReward(stateChainsNext[i], bestActions[i].action, bestActions[i].value)
+            newReward.observedGammaReward = oldReward.observedGammaReward
+            stateChainsCurrent[i][length-1].reward = newReward
+        end
+
+        local batchInput, batchTarget = network.stateChainsToBatch(stateChainsCurrent)
+        return batchInput, batchTarget, stateChainsCurrent
+    end
+end
+
+function memory.rowToState(row)
+    -- luasql returns every column value as a string, so numeric fields are converted explicitly
+    local screen_jpg = torch.deserialize(row.screen_jpg, "ascii")
+    local isDummy = false
+    if tonumber(row.is_dummy) == 1 then isDummy = true end
+    local action = Action.new(tonumber(row.action_arrow), tonumber(row.action_button))
+    local reward = Reward.new(tonumber(row.reward_score_diff), tonumber(row.reward_x_diff), tonumber(row.reward_level_beaten), tonumber(row.reward_expected_gamma), tonumber(row.reward_expected_gamma_raw), tonumber(row.reward_observed_gamma))
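+    -- State.new expects the same argument order as used in states.createEmptyState():
+    -- (id, screen, score, countLifes, levelBeatenStatus, marioGameStatus, playerX, marioImage, isLevelEnding, action, reward)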
+    local state = State.new(tonumber(row.id), screen_jpg, tonumber(row.score), tonumber(row.count_lifes), tonumber(row.level_beaten_status), tonumber(row.mario_game_status), tonumber(row.player_x), tonumber(row.mario_image), tonumber(row.is_level_ending) == 1, action, reward)
+    state.isDummy = isDummy
+    return state
+end
+
+function memory.plot(subsampling)
+
+end
+
+return memory
diff --git a/states_sqlite.lua b/states_sqlite.lua
new file mode 100644
index 0000000..91652b8
--- /dev/null
+++ b/states_sqlite.lua
@@ -0,0 +1,374 @@
+local states = {
+    data = {},
+    dataAll = {}
+}
+
+states.MAX_SIZE = 5000
+states.CASCADE_BACK = 100
+--states.CASCADE_STOP = 0.05
+--states.CASCADE_INFLUENCE = 0.25
+--states.idx = 0
+
+--[[
+function states.addEntry(screen, action, bestActionValue, score, playerX, countLifes, levelBeatenStatus, marioGameStatus)
+    states.idx = states.idx + 1
+    local entry = {
+        idx=states.idx,
+        screen=screen,
+        action=action,
+        bestActionValue=bestActionValue,
+        score=score,
+        playerX=playerX,
+        countLifes=countLifes,
+        levelBeatenStatus=levelBeatenStatus,
+        marioGameStatus=marioGameStatus,
+        reward=nil
+    }
+    table.insert(states.data, entry)
+    --print("Added entry ", entry.idx, " to states, now", #states.data)
+    --print("First entry idx:", states.data[1].idx)
+    --print("Last entry idx:", states.data[#states.data].idx)
+    if #states.data > states.MAX_SIZE then
+        table.remove(states.data, 1)
+        --print("Removed entry from states, now", #states.data)
+    end
+    return entry
+end
+--]]
+function states.addEntry(pState)
+    table.insert(states.data, pState)
+    table.insert(states.dataAll, pState)
+    if #states.data > states.MAX_SIZE then
+        table.remove(states.data, 1)
+    end
+    if #states.dataAll > states.MAX_SIZE then
+        table.remove(states.dataAll, 1)
+    end
+end
+
+--[[
+function states.setLastReward(reward)
+    local entry = states.getLastEntry()
+    if entry == nil then
+        print("[INFO] could not set reward of last state, because there is no last state")
+    else
+        --print("Setting reward of ", entry.idx, " to ", reward)
+        entry.reward = reward
+    end
+end
+--]]
+
+function states.getLastEntry()
+    if #states.data > 0 then
+        return states.data[#states.data]
+    else
+        return nil
+    end
+end
+
+function states.getEntry(index)
+    -- index > 0 counts from the front, index < 0 counts from the back (-1 = last entry)
+    if index == 0 then
+        error("index must not be 0")
+    elseif index > 0 then
+        return states.data[index]
+    elseif index < 0 then
+        --print("#states.data", #states.data)
+        --print("returning index ", #states.data - math.abs(index+1))
+        --print("is nil:", states.data[#states.data - math.abs(index+1)] == nil)
+        return states.data[#states.data - math.abs(index+1)]
+    end
+end
+
+function states.getLastEntries(nEntries)
+    assert(#states.data >= nEntries, string.format("%d requested, %d available", nEntries, #states.data))
+    local result = {}
+    for i=1,nEntries do
+        table.insert(result, states.data[#states.data-nEntries+i])
+    end
+    return result
+end
+
+--[[
+function states.cascadeBackReward(reward)
+    --print(string.format("cascading reward %.2f", rewards.getDirectReward(reward)))
+    local last = math.max(#states.data - 1, 1)
+    local first = math.max(last - states.CASCADE_BACK + 1, 1)
+    local counter = last - first
+    local direct = rewards.getDirectReward(reward)
+    for i=first,last do
+        local s = states.data[i]
+        if not s.isDummy then
+            local oldGamma = s.reward.observedGammaReward
+            local cascadedGamma = torch.pow(GAMMA_OBSERVED, counter+1) * direct
+            --local cascadedGamma = GAMMA_OBSERVED * 1/(counter+1) * direct
+            --local newGamma = states.CASCADE_INFLUENCE * (cascadedGamma) + (1 - states.CASCADE_INFLUENCE) * oldGamma
+            --local newGamma = oldGamma + cascadedGamma
+            local newGamma = oldGamma +
cascadedGamma + print(string.format("Cascade %.2f at i=%d/c=%d from %.4f to %.4f by %.4f", direct, i, counter, oldGamma, newGamma, cascadedGamma)) + s.reward.observedGammaReward = newGamma + counter = counter - 1 + --if math.abs(cascadedGamma) < states.CASCADE_STOP then + -- break + --end + end + end +end +--]] + +function states.cascadeBackReward(reward) + --print(string.format("cascading reward %.2f", rewards.getDirectReward(reward))) + local epsilon = 0.0001 + local last = math.max(#states.data - 1, 1) + local first = math.max(last - states.CASCADE_BACK + 1, 1) + local exp = last - first + 1 + local direct = rewards.getDirectReward(reward) + -- TODO start at the last state, move towards the first state, break if propagated reward is lower than epsilon, + -- should be faster + if direct > epsilon or direct < (-1)*epsilon then + for i=first,last do + local s = states.data[i] + if not s.isDummy then + local oldGamma = s.reward.observedGammaReward + local cascadedGamma = torch.pow(GAMMA_OBSERVED, exp) * direct + --local cascadedGamma = GAMMA_OBSERVED * 1/(counter+1) * direct + --local newGamma = states.CASCADE_INFLUENCE * (cascadedGamma) + (1 - states.CASCADE_INFLUENCE) * oldGamma + --local newGamma = oldGamma + cascadedGamma + local newGamma = oldGamma + cascadedGamma + --print(string.format("Cascade %.2f at i=%d/c=%d from %.4f to %.4f by %.4f", direct, i, exp, oldGamma, newGamma, cascadedGamma)) + s.reward.observedGammaReward = newGamma + --if math.abs(cascadedGamma) < states.CASCADE_STOP then + -- break + --end + end + exp = exp - 1 + end + end +end + +function states.decompressScreen(screen) + return util.decompressJPG(screen) +end + +function states.addToMemory() + print("start addToMemory") + -- -1 because the last state doesnt have a reward and action yet + for i=1,#states.data-1 do + local state = states.data[i] + local id = state.id + local id1000 = id % 1000 + local val = false + if (id1000 >= 750 and id1000 < 800) or (id1000 >= 950 and id1000 < 1000) then + val = true + end + memory.addState(state, val) + end + print("end addToMemory") +end + +function states.clear(refill) + states.data = {} + if refill or refill == nil then + states.fillWithEmptyStates() + end +end + +function states.fillWithEmptyStates(minNumberOfStates) + minNumberOfStates = minNumberOfStates or (STATES_PER_EXAMPLE+1) + for i=1,minNumberOfStates do + --table.insert(states.data, states.createEmptyState()) + states.addEntry(states.createEmptyState()) + end +end + +function states.createEmptyState() + local screen = torch.zeros(3, IMG_DIMENSIONS[2], IMG_DIMENSIONS[3]) + screen = image.drawText(screen, string.format("F%d", STATS.FRAME_COUNTER), 0, 0, {color={255,255,255}}) + if IMG_DIMENSIONS[1] == 1 then + screen = util.rgb2y(screen) + end + local screenCompressed = util.compressJPG(screen) + local score = 0 + local countLifes = 0 + local levelBeatenStatus = 0 + local marioGameStatus = 0 + local playerX = 0 + local marioImage = 0 + local isLevelEnding = false + local action = actions.createRandomAction() + local reward = Reward.new(0, 0, 0, 0, 0, 0) + local s = State.new(nil, screenCompressed, score, countLifes, levelBeatenStatus, marioGameStatus, playerX, marioImage, isLevelEnding, action, reward) + s.isDummy = true + return s +end + +--[[ +function states.plotRewards() + local points = {} + for i=1,#states.dataAll do + local state = states.dataAll[i] + local r = state.reward + -- the newest added state should not have any reward yet + if r == nil then + table.insert(points, {i, 0, 0, 0, 0}) + else + 
table.insert(points, {i, rewards.getDirectReward(r), r.observedGammaReward, r.expectedGammaReward, r.expectedGammaRewardRaw})
+        end
+    end
+    display.plot(points, {win=21, labels={'State', 'Direct Reward', 'OGR', 'EGR (after gamma multiply)', 'EGR (before gamma multiply)'}, title='Reward per state (direct reward, observed/expected gamma reward)'})
+end
+--]]
+
+function states.plotRewards(nBackMax)
+    nBackMax = nBackMax or 200
+    local points = {}
+    for i=math.max(#states.dataAll-nBackMax, 1),#states.dataAll do
+        local state = states.dataAll[i]
+        local r = state.reward
+        -- the newest added state should not have any reward yet
+        if r == nil then
+            table.insert(points, {i, 0, 0, 0, 0})
+        else
+            table.insert(points, {i, rewards.getDirectReward(r), r.observedGammaReward, r.expectedGammaReward, r.expectedGammaRewardRaw})
+        end
+    end
+    display.plot(points, {win=21, labels={'State', 'Direct Reward', 'OGR', 'EGR (after gamma multiply)', 'EGR (before gamma multiply)'}, title='Reward per state (direct reward, observed/expected gamma reward)'})
+end
+
+--[[
+function states.plotStateChain(stateChain, windowId, title, width)
+    windowId = windowId or 20
+    title = title or "State Chain"
+
+    local imgsDisp = torch.zeros(#stateChain, IMG_DIMENSIONS_Q_HISTORY[1], IMG_DIMENSIONS_Q_HISTORY[2], IMG_DIMENSIONS_Q_HISTORY[3])
+    for i=1,#stateChain do
+        imgsDisp[i] = states.decompressScreen(stateChain[i].screen)
+    end
+
+    local out = image.toDisplayTensor{input=imgsDisp, nrow=#stateChain, padding=1}
+
+    if width then
+        display.image(out, {win=windowId, width=width, title=title})
+    else
+        display.image(out, {win=windowId, title=title})
+    end
+end
+--]]
+
+function states.stateChainToImage(stateChain, net)
+    local batchSize = 1
+    local lastState = stateChain[#stateChain]
+    local previousStates = {}
+    local previousScreens = {}
+    for i=1,#stateChain-1 do
+        table.insert(previousStates, stateChain[i])
+        local screen = states.decompressScreen(stateChain[i].screen)
+        local screenWithAction = torch.zeros(screen:size(1), screen:size(2)+16, screen:size(3))
+        screenWithAction[{{1,screen:size(1)}, {1,screen:size(2)}, {1, screen:size(3)}}] = screen
+        if screen:size(1) == 1 then
+            screenWithAction = torch.repeatTensor(screenWithAction, 3, 1, 1)
+        end
+        local actionStr = actions.actionToString(stateChain[i].action)
+        local x = 2
+        local y = screen:size(2) + 2
+        screenWithAction = image.drawText(screenWithAction, actionStr, x, y, {color={255,255,255}})
+        screenWithAction = image.scale(screenWithAction, IMG_DIMENSIONS_Q_LAST[2], IMG_DIMENSIONS_Q_LAST[3])
+        table.insert(previousScreens, screenWithAction)
+    end
+
+    local lastScreen = states.decompressScreen(lastState.screen)
+
+    if net ~= nil then
+        -- Get the transformation matrix from the AffineTransformMatrixGenerator
+        local transfo = nil
+        local function findTransformer(m)
+            local name = torch.type(m)
+
+            if name:find('AffineTransformMatrixGenerator') then
+                transfo = m
+            end
+        end
+        net:apply(findTransformer)
+
+        if transfo ~= nil then
+            transfo = transfo.output:float()
+
+            --[[
+            print("transfo size", transfo:size(1), transfo:size(2), transfo:size(3))
+            print("Transformation matrix values:")
+            for a=1,transfo:size(1) do
+                for b=1,transfo:size(2) do
+                    for c=1,transfo:size(3) do
+                        print(string.format("%d %d %d = %.4f", a, b, c, transfo[a][b][c]))
+                    end
+                end
+            end
+            --]]
+
+            local corners = torch.Tensor{{-1,-1,1},{-1,1,1},{1,-1,1},{1,1,1}}
+            -- Compute the positions of the corners in the original image
+            local points = torch.bmm(corners:repeatTensor(batchSize,1,1),
transfo:transpose(2,3)) + -- Ensure these points are still in the image + local imageSize = lastScreen:size(2) + + --[[ + print("Corner points before fix:") + for batch=1,batchSize do + for pt=1,4 do + local point = points[batch][pt] + print(string.format("(%.4f, %.4f)", point[1], point[2])) + end + end + --]] + + points = torch.floor((points+1)*imageSize/2) + points:clamp(1,imageSize-1) + + --[[ + print("Corner points after fix:") + for batch=1,batchSize do + for pt=1,4 do + local point = points[batch][pt] + print(string.format("(%.4f, %.4f)", point[1], point[2])) + end + end + --]] + + for batch=1,batchSize do + for pt=1,4 do + local point = points[batch][pt] + --print(string.format("p2 %.4f %.4f", point[1], point[2])) + for chan=1,IMG_DIMENSIONS_Q_LAST[1] do + local max_value = lastScreen[chan]:max()*1.1 + -- We add 4 white pixels because one can disappear in image rescaling + lastScreen[chan][point[1]][point[2]] = max_value + lastScreen[chan][point[1]+1][point[2]] = max_value + lastScreen[chan][point[1]][point[2]+1] = max_value + lastScreen[chan][point[1]+1][point[2]+1] = max_value + end + end + end + end + end + + if lastScreen:size(1) == 1 then + lastScreen = torch.repeatTensor(lastScreen, 3, 1, 1) + end + table.insert(previousScreens, lastScreen) + + local result = image.toDisplayTensor{input=previousScreens, nrow=#previousScreens, padding=1} + --local result = image.toDisplayTensor{input={lastScreen}, nrow=1, padding=1} + --[[ + local c = math.max(IMG_DIMENSIONS_Q_LAST[1], IMG_DIMENSIONS_Q_HISTORY[1]) + local h = math.max(prev:size(3)+16, IMG_DIMENSIONS_Q_LAST[2]) + local w = prev:size(2)+IMG_DIMENSIONS_Q_LAST[3] + local result = torch.zeros(c, h, w) + result[{{1,c}, {1,prev:size(2)}, {1,prev:size(3)}}] = prev + for i=1,#previousStates do + image.draw() + end + --]] + + return result +end + +return states diff --git a/train_sqlite.lua b/train_sqlite.lua new file mode 100644 index 0000000..ac03b7e --- /dev/null +++ b/train_sqlite.lua @@ -0,0 +1,524 @@ +print("------------------------") +print("START") +print("------------------------") + +GPU = 0 +SEED = 43 +lsne_memory = memory + +require 'torch' +require 'image' +require 'nn' +require 'optim' +memory = require 'memory_sqlite' +network = require 'network' +actions = require 'actions' +Action = require 'action' +util = require 'util' +states = require 'states_sqlite' +State = require 'state' +rewards = require 'rewards' +Reward = require 'reward' +VAE = require 'VAE' +ForgivingMSECriterion = require 'layers.ForgivingMSECriterion' +ForgivingAbsCriterion = require 'layers.ForgivingAbsCriterion' +ok, display = pcall(require, 'display') +if not ok then print('display not found. unable to plot') end + +require 'cutorch' +require 'cunn' +require 'cudnn' +if GPU >= 0 then + print(string.format("Using gpu device %d", GPU)) + cutorch.setDevice(GPU + 1) + cutorch.manualSeed(SEED) + + -- Saves 40% time according to http://torch.ch/blog/2016/02/04/resnets.html + cudnn.fastest = true + cudnn.benchmark = true +end +math.randomseed(SEED) +torch.manualSeed(SEED) +torch.setdefaulttensortype('torch.FloatTensor') + +FPS = movie.get_game_info().fps +--local fps = 100 +--local reactEveryNthFrame = 6 -- roughly 26/4, i.e. 
every 0.25s +--REACT_EVERY_NTH_FRAME = math.floor(FPS / 5) +REACT_EVERY_NTH_FRAME = 5 +print(string.format("FPS: %d, Reacting every %d frames", FPS, REACT_EVERY_NTH_FRAME)) + +SCREENSHOT_FILEPATH = "/media/ramdisk/mario-ai-screenshots/current-screen.png" +IMG_DIMENSIONS = {1, 64, 64} +--IMG_DIMENSIONS_Q = {1, 48, 48} +IMG_DIMENSIONS_Q_HISTORY = {1, 32, 32} +IMG_DIMENSIONS_Q_LAST = {1, 64, 64} +IMG_DIMENSIONS_AE = {1, 128, 128} +BATCH_SIZE = 16 +STATES_PER_EXAMPLE = 4 +GAMMA_EXPECTED = 0.9 +GAMMA_OBSERVED = 0.9 +MAX_GAMMA_REWARD = 100 +P_EXPLORE_START = 0.8 +P_EXPLORE_END = 0.1 +P_EXPLORE_END_AT = 400000 +STATS = { + STATE_ID = 0, + FRAME_COUNTER = 0, + ACTION_COUNTER = 0, + CURRENT_DIRECT_REWARD_SUM = 0, + CURRENT_OBSERVED_GAMMA_REWARD_SUM = 0, + AVERAGE_REWARD_DATA = {}, + AVERAGE_LOSS_DATA = {}, + LAST_BEST_ACTION_VALUE = 0, + P_EXPLORE_CURRENT = P_EXPLORE_START +} +LAST_SAVE_STATE_LOAD = 0 + +print("Loading/Creating network...") +Q_L2_NORM = 1e-6 +Q_CLAMP = 10 +Q = network.createOrLoadQ() + +PARAMETERS, GRAD_PARAMETERS = Q:getParameters() +--CRITERION = nn.ForgivingAbsCriterion() +CRITERION = nn.ForgivingMSECriterion() +--CRITERION = nn.MSECriterion() +OPTCONFIG = {learningRate=0.001, beta1=0.9, beta2=0.999} +DECAY = 1.0 +--OPTCONFIG = {learningRate=0.001, momentum=0.9} +OPTSTATE = {} + +--MODEL_AE, CRITERION_AE_LATENT, CRITERION_AE_RECONSTRUCTION, PARAMETERS_AE, GRAD_PARAMETERS_AE = VAE.createVAE() +--OPTCONFIG_AE = { learningRate=0.001 } +--OPTSTATE_AE = {} + +--print("Number of parameters in AE:", network.getNumberOfParameters(MODEL_AE)) +print("Number of parameters in Q:", network.getNumberOfParameters(Q)) + +print("Loading memory...") +memory.load() + +print("Loading stats...") +util.loadStats() +STATS.STATE_ID = memory.getMaxStateId(1) + +print("Starting loop.") + +--[[ +local points = {} +for i=1,10000 do + table.insert(points, math.random()*math.random()*math.random()) +end +table.sort(points) +points2 = {} +for i=1,10000 do + table.insert(points2, {i, points[i]}) +end +display.plot(points2, {win=5, labels={'entry', 'val'}, title='Random values'}) +--]] + +function on_paint() + local lastLastState = states.getEntry(-2) + local lastState = states.getEntry(-1) + + -- last best action value + --[[ + if lastLastState ~= nil then + gui.text(1, 1, string.format("LLRew: %.2f/%.2f", rewards.getDirectReward(lastLastState.reward), lastLastState.reward.expectedGammaReward)) + gui.text(1+175, 1, string.format("LLBAV: %.2f", STATS.LAST_BEST_ACTION_VALUE or 0)) + end + --]] + gui.text(1+350-15, 1, string.format("Memory: %d/%d", memory.getCountEntries(false), memory.getCountEntries(true))) + + --[[ + local observedGammaRewards = "oGR: " + local nStates = #states.data + for i=1,math.min(8, nStates) do + local s = states.data[nStates-i+1] + local ogr = s.reward and s.reward.observedGammaReward or 0 -- the last (most recent) state does not have a reward yet + observedGammaRewards = observedGammaRewards .. 
string.format("%.2f ", ogr) + end + gui.text(1, 15, observedGammaRewards) + --]] +end + +function on_frame_emulated() + --print("Start") + local lastLastState = states.getEntry(-2) + local lastState = states.getEntry(-1) + STATS.FRAME_COUNTER = movie.currentframe() + + -- last best action value + --[[ + if lastLastState ~= nil then + gui.text(1, 1, string.format("LLRew: %.2f/%.2f", rewards.getDirectReward(lastLastState.reward), rewards.getSum(lastLastState.reward))) + gui.text(1+175, 1, string.format("LLBAV: %.2f", LAST_BEST_ACTION_VALUE or 0)) + end + gui.text(1+350, 1, string.format("Memory: %d", #memory.data)) + --]] + + --if (STATS.FRAME_COUNTER+1) % REACT_EVERY_NTH_FRAME == 0 then + -- actions.endAllActions() + -- return + if STATS.FRAME_COUNTER % REACT_EVERY_NTH_FRAME ~= 0 then + return + end + + STATS.ACTION_COUNTER = STATS.ACTION_COUNTER + 1 + + if STATS.ACTION_COUNTER % 4000 == 0 then + print("Garbage collection...") + collectgarbage() + util.sleep(2) + end + + if lastState == nil then print("[NOTE] lastState is nil") end + if lastLastState == nil then print("[NOTE] lastLastState is nil") end + + --print("Before new state") + local state = State.new(nil, getScreenCompressed(), util.getCurrentScore(), util.getCountLifes(), util.getLevelBeatenStatus(), util.getMarioGameStatus(), util.getPlayerX(), util.getMarioImage(), util.isLevelEnding()) + states.addEntry(state) -- getLastEntries() depends on this, don't move it after the next code block + --print("Score:", score, "Level:", util.getLevel(), "x:", playerX, "status:", marioGameStatus, "levelBeatenStatus:", levelBeatenStatus, "count lifes:", countLifes) + --print("Mario Image", util.getMarioImage()) + + -- Calculate reward + local rew, bestAction, bestActionValue = rewards.statesToReward(states.getLastEntries(STATES_PER_EXAMPLE)) + lastState.reward = rew + --print(string.format("[Reward] R=%.2f DR=%.2f SDR=%.2f XDR=%.2f LBR=%.2f EGR=%.2f", rewards.getSumExpected(lastState.reward), rewards.getDirectReward(lastState.reward), lastState.reward.scoreDiffReward, lastState.reward.xDiffReward, lastState.reward.levelBeatenReward, lastState.reward.expectedGammaReward)) + --print(string.format("Cascading reward %.4f", rewards.getDirectReward(lastState.reward))) + states.cascadeBackReward(lastState.reward) + STATS.LAST_BEST_ACTION_VALUE = bestActionValue + + -- Add to memory + --[[ + local pastStates = states.getLastEntries(STATES_PER_EXAMPLE+1) + table.remove(pastStates) -- pop last state as that state is currently still in progress (eg no reward yet) + local ac1000 = STATS.ACTION_COUNTER % 1000 + local validation = (ac1000 >= 750 and ac1000 < 800) or (ac1000 >= 950 and ac1000 < 1000) + memory.addEntry(pastStates, state, validation) + --]] + + -- show state chain + -- must happen before training as it might depend on network's current output + --print("Before display") + display.image(states.stateChainToImage(states.getLastEntries(STATES_PER_EXAMPLE), Q), {win=17, title="Last states"}) + + --print("Before plot") + -- plot average rewards + if STATS.ACTION_COUNTER % 3 == 0 then + states.plotRewards() + end + + -------------------- + + --[[ + if STATS.ACTION_COUNTER % 250 == 0 then + table.insert(STATS.AVERAGE_REWARD_DATA, {STATS.ACTION_COUNTER, STATS.CURRENT_DIRECT_REWARD_SUM / 50, STATS.CURRENT_OBSERVED_GAMMA_REWARD_SUM / 50}) + STATS.CURRENT_DIRECT_REWARD_SUM = 0 + STATS.CURRENT_OBSERVED_GAMMA_REWARD_SUM = 0 + plotAverageReward() + end + --]] + --print("Before plotAvg") + if STATS.ACTION_COUNTER % states.MAX_SIZE == 0 then + local 
directRewardSum = 0 + local observedGammaRewardSum = 0 + local expectedGammaRewardSum = 0 + for i=1,#states.dataAll do + if states.dataAll[i].reward ~= nil then + directRewardSum = directRewardSum + rewards.getDirectReward(states.dataAll[i].reward) + observedGammaRewardSum = observedGammaRewardSum + states.dataAll[i].reward.observedGammaReward + expectedGammaRewardSum = expectedGammaRewardSum + states.dataAll[i].reward.expectedGammaReward + end + end + table.insert(STATS.AVERAGE_REWARD_DATA, {STATS.ACTION_COUNTER, directRewardSum / #states.dataAll, observedGammaRewardSum / #states.dataAll, expectedGammaRewardSum / #states.dataAll}) + plotAverageReward() + end + + if STATS.ACTION_COUNTER % 10000 == 0 then + --print("Reevaluating rewards in memory...") + --memory.reevaluateRewards() + --print("Reordering memory...") + --memory.reorderByDirectReward() + --print("Plotting memory...") + --memory.plot(10) + --print("Reordering finished") + --print(string.format("1st: %.4f", memory.data[1][2].reward and rewards.getSumExpected(memory.data[1][2].reward) or 123.456)) + --print(string.format("2st: %.4f", memory.data[2][2].reward and rewards.getSumExpected(memory.data[2][2].reward) or 123.456)) + --print(string.format("3st: %.4f", memory.data[3][2].reward and rewards.getSumExpected(memory.data[3][2].reward) or 123.456)) + --print(string.format("4st: %.4f", memory.data[4][2].reward and rewards.getSumExpected(memory.data[4][2].reward) or 123.456)) + --print(string.format("5st: %.4f", memory.data[5][2].reward and rewards.getSumExpected(memory.data[5][2].reward) or 123.456)) + --print("Done.") + end + + --print("Before train") + if (STATS.ACTION_COUNTER == 250 and memory.isTrainDataFull()) + or STATS.ACTION_COUNTER % 5000 == 0 then + --print("Training AE...") + --for i=1,25 do + -- trainAE() + --end + + print("Training...") + local nTrainingBatches = 3500 --math.max(math.floor(#memory.trainData / BATCH_SIZE), 51) + local nTrainingGroups = 50 -- number of plot points per training epoch + local nTrainBatchesPerGroup = math.floor(nTrainingBatches / nTrainingGroups) + local nValBatchesPerGroup = math.floor(nTrainBatchesPerGroup * 0.10) + for i=1,nTrainingGroups do + local sumLossTrain = 0 + local sumLossVal = 0 + for j=1,nTrainBatchesPerGroup do + local loss = trainOneBatch() + sumLossTrain = sumLossTrain + loss + print(string.format("[BATCH %d/%d] loss=%.8f", (i-1)*nTrainBatchesPerGroup + j, nTrainingBatches, loss)) + end + for j=1,nValBatchesPerGroup do + sumLossVal = sumLossVal + valOneBatch() + end + table.insert(STATS.AVERAGE_LOSS_DATA, {#STATS.AVERAGE_LOSS_DATA+1, sumLossTrain/nTrainBatchesPerGroup, sumLossVal/nValBatchesPerGroup}) + plotAverageLoss() + end + + --[[ + print("Training...") + local nBatchesTrain = 2500 + local nBatchesVal = 250 + local sumLossTrain = 0 + local sumLossVal = 0 + for i=1,nBatchesTrain do + local loss = trainOneBatch() + sumLossTrain = sumLossTrain + loss + print(string.format("[BATCH] loss=%.8f", loss)) + end + for i=1,nBatchesVal do + sumLossVal = sumLossVal + valOneBatch() + end + table.insert(STATS.AVERAGE_LOSS_DATA, {STATS.ACTION_COUNTER, sumLossTrain/nBatchesTrain, sumLossVal/nBatchesVal}) + plotAverageLoss() + --]] + + OPTCONFIG.learningRate = OPTCONFIG.learningRate * DECAY + print(string.format("[LEARNING RATE] %.12f", OPTCONFIG.learningRate)) + end + + --print("Before chooseAction") + --print("bestAction:", bestAction, bestAction.arrow, bestAction.button) + state.action = chooseAction(states.getLastEntries(STATES_PER_EXAMPLE), false, bestAction) + 
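+    -- chooseAction() is epsilon-greedy: with probability STATS.P_EXPLORE_CURRENT it randomizes the arrow
+    -- and/or button, otherwise it returns bestAction (or asks the network); see its definition further below.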
--states.addEntry(screen, action, bestActionValue, score, playerX, countLifes, levelBeatenStatus, marioGameStatus) + + + --print("Before level ended") + local levelEnded = state.levelBeatenStatus == 128 or state.marioGameStatus == 2 + if levelEnded or (STATS.ACTION_COUNTER - LAST_SAVE_STATE_LOAD) > 1000 then + print("Reloading saved gamestate and saving states...") + states.addToMemory() + states.clear() + + -- Reload save state if level was beaten or mario died + util.loadRandomTrainingSaveState() + LAST_SAVE_STATE_LOAD = STATS.ACTION_COUNTER + else + actions.endAllActions() + actions.startAction(state.action) + end + + --print("Before decay") + -- decay exploration rate + local pPassed = math.min(STATS.ACTION_COUNTER / P_EXPLORE_END_AT, 1.0) + STATS.P_EXPLORE_CURRENT = (1-pPassed) * P_EXPLORE_START + pPassed * P_EXPLORE_END + if STATS.ACTION_COUNTER % 250 == 0 then + print(string.format("[EXPLORE P] %.2f", STATS.P_EXPLORE_CURRENT)) + end + + --print("Before save") + -- save + if STATS.ACTION_COUNTER % 10000 == 0 then + print("Saving stats...") + util.saveStats() + print("Saving network...") + network.save() + --print("Saving memory...") + --memory.save() + + --[[ + print("Clearing data and reloading...") + states.data = {} + states.dataAll = {} + memory.valData = {} + memory.trainData = {} + collectgarbage() + util.sleep(1) + memory.load() + collectgarbage() + util.sleep(1) + states.clear() -- refills with empty states + --]] + end +end + +function plotAverageReward() + display.plot(STATS.AVERAGE_REWARD_DATA, {win=3, labels={'action counter', 'direct', 'observed gamma', 'expected gamma'}, title='Average rewards per N actions'}) +end + +function plotAverageLoss() + display.plot(STATS.AVERAGE_LOSS_DATA, {win=4, labels={'batch group', 'training', 'validation'}, title='Average loss per batch'}) +end + +function getScreen() + -------------------- + -- Estimate the reward of the chosen action, + -- add it to memory + -------------------- + -- Current State + -- screenshot_bitmap() => DBITMAP + -- DBITMAP members: + -- blit_scaled + -- blit_porterduff + -- __gc + -- draw_clip + -- pset + -- __newindex + -- size + -- hash + -- blit_scaled_porterduff + -- draw + -- draw_clip_outside + -- blit + -- save_png + -- adjust_transparency + -- pget + -- draw_outside + -- __index + --local screen = gui.screenshot_bitmap() + local fp = "/media/ramdisk/mario-ai-screenshots/current-screen.jpeg" + gui.screenshot(fp) + local screen = image.load(fp, 3, "float"):clone() + screen = image.scale(screen, IMG_DIMENSIONS[2], IMG_DIMENSIONS[3]):clone() + if IMG_DIMENSIONS[1] == 1 then + screen = util.rgb2y(screen) + end + return screen +end + +function getScreenCompressed() + local fp = SCREENSHOT_FILEPATH + gui.screenshot(fp) + return util.loadJPGCompressed(fp, IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3]) +end + +function trainAE() + local batchInput = memory.getAutoencoderBatch(BATCH_SIZE) + VAE.train(batchInput, MODEL_AE, CRITERION_AE_LATENT, CRITERION_AE_RECONSTRUCTION, PARAMETERS_AE, GRAD_PARAMETERS_AE, OPTCONFIG_AE, OPTSTATE_AE) +end + +function trainOneBatch() + -- Train + --FRAME_COUNTER % BATCH_SIZE == 0 + if memory.getCountEntries(false) >= memory.MEMORY_TRAINING_MIN_SIZE then + local batchInput, batchTarget, stateChains = memory.getBatch(BATCH_SIZE, false, true) + --print("Training with a batch of size " .. 
batchInput:size(1)) + local loss = network.forwardBackwardBatch(batchInput, batchTarget) + display.image(states.stateChainToImage(stateChains[1], Q), {win=18, title="Last training batch 1st example"}) + return loss + else + return 0 + end +end + +function valOneBatch() + if memory.getCountEntries(true) >= BATCH_SIZE then + local batchInput, batchTarget = memory.getValidationBatch(BATCH_SIZE) + return network.batchToLoss(batchInput, batchTarget) + else + return 0 + end +end + +--[[ +function chooseAction(lastState, state, perfect, bestAction) + perfect = perfect or false + local _action, _actionValue + if lastState == nil or state == nil or (not perfect and math.random() < STATS.P_EXPLORE_CURRENT) then + _action = util.getRandomEntry(actions.ACTIONS_NETWORK) + --print("Chossing action randomly:", _action) + else + if bestAction ~= nil then + _action = bestAction + else + -- Use network to approximate action with maximal value + _action, _actionValue = network.approximateBestAction(lastState, state) + --print("Q approximated action:", _action, actions.ACTION_TO_BUTTON_NAME[_action]) + end + end + + return _action +end +--]] + +function chooseAction(lastStates, perfect, bestAction) + perfect = perfect or false + local _action, _actionValue + if not perfect and math.random() < STATS.P_EXPLORE_CURRENT then + if bestAction == nil or math.random() < 0.5 then + -- randomize both + _action = Action.new(util.getRandomEntry(actions.ACTIONS_ARROWS), util.getRandomEntry(actions.ACTIONS_BUTTONS)) + else + -- randomize only arrow or only button + if math.random() < 0.5 then + _action = Action.new(util.getRandomEntry(actions.ACTIONS_ARROWS), bestAction.button) + else + _action = Action.new(bestAction.arrow, util.getRandomEntry(actions.ACTIONS_BUTTONS)) + end + end + --print("Chossing action randomly:", _action) + else + if bestAction ~= nil then + _action = bestAction + else + -- Use network to approximate action with maximal value + _action, _actionValue = network.approximateBestAction(lastStates) + --print("Q approximated action:", _action, actions.ACTION_TO_BUTTON_NAME[_action]) + end + end + + return _action +end + +--setToFastSpeed() +--for lcid=1,8 do +-- local port, controller = input.lcid_to_pcid2(lcid) +-- print("controller_info:", port, controller, input.controller_info(port, controller)) +--end + +--print("controller type 0,0:", input.controllertype(0, 0)) +--print("controller type 0,1:", input.controllertype(0, 1)) +--print("controller type 1,0:", input.controllertype(1, 0)) +--print("controller type 1,1:", input.controllertype(1, 1)) +--print("controller_info A:", get_controller_info()) +--print("controller_info B:", 1, 0, input.controller_info(1, 0)) +--for i=0,128 do +-- print("input.joyget("..i.."):", input.joyget(i)) +--end +--startAction(ACTION_BUTTON_START) +--print("input.raw:", input.raw()) +--input.do_button_action("gamepad-1-A", 1, 1) + +--endAllActions() +--startAction(ACTION_BUTTON_SELECT) +--endAllActions() +--startAction(ACTION_BUTTON_A) +--endAllActions() +--startAction(ACTION_BUTTON_B) +--endAllActions() +--startAction(ACTION_BUTTON_X) +--endAllActions() +--startAction(ACTION_BUTTON_Y) +--endAllActions() +--exit() +--movie.to_rewind("lvl1.lsmv") +actions.endAllActions() +util.loadRandomTrainingSaveState() +util.setGameSpeedToVeryFast() +states.fillWithEmptyStates() +gui.repaint()
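Note: memory.getRandomWeightedIndex(skew, validation, skewStrength) is left as an empty stub in memory_sqlite.lua above and is only referenced from a commented-out line in getBatch. A minimal sketch of what its body could look like is given below; it is not part of the patch, and both the "top" skew semantics (bias towards the newest ids) and the default skewStrength are assumptions.

-- Sketch only: pick a random row id, optionally skewed towards the newest entries.
-- Assumes it lives inside memory_sqlite.lua so that the local `con` connection is in scope.
function memory.getRandomWeightedIndex(skew, validation, skewStrength)
    skewStrength = skewStrength or 2
    local val_int = validation and 1 or 0
    local count = memory.getCountEntries(validation)
    if count == 0 then return nil end
    local p = math.random()
    if skew == "top" then
        -- raising p (in [0,1)) to a power > 1 pushes it towards 0, i.e. towards the newest rows
        p = math.pow(p, skewStrength)
    end
    local offset = math.floor(p * count)
    local cur = assert(con:execute(string.format(
        "SELECT id FROM states WHERE is_validation = %d ORDER BY id DESC LIMIT 1 OFFSET %d",
        val_int, offset)))
    local row = cur:fetch({}, "a")
    return row and tonumber(row.id) or nil
end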