forked from aleju/mario-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
4,573 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
require 'torch' | ||
require 'nn' | ||
require 'nngraph' | ||
require 'layers.GaussianCriterion' | ||
require 'layers.KLDCriterion' | ||
require 'layers.Sampler' | ||
|
||
local VAE = {} | ||
VAE.continuous = false | ||
|
||
--- Build the variational autoencoder graph and the objects needed to train it.
--
-- Reads the global IMG_DIMENSIONS_AE (three entries, multiplied together to
-- get the flattened input size) and the module flag VAE.continuous.
--
-- The graph output depends on VAE.continuous so that it matches what
-- VAE.train unpacks from model:forward():
--   * false: {reconstruction, mean, log_var}, paired with nn.BCECriterion
--     (sizeAverage disabled).
--   * true:  {reconstruction, reconstruction_var, mean, log_var}, paired with
--     nn.GaussianCriterion.
--     NOTE(review): assumes 'layers.GaussianCriterion' registers
--     nn.GaussianCriterion and accepts a {mean, log_var} table as input --
--     confirm against that file.
--
-- @return model                     nn.gModule wrapping encoder, sampler and decoder
-- @return criterion_latent          nn.KLDCriterion instance
-- @return criterion_reconstruction  reconstruction criterion (see above)
-- @return parameters                first result of model:getParameters()
-- @return gradients                 second result of model:getParameters()
function VAE.createVAE()
    local input_size = IMG_DIMENSIONS_AE[1] * IMG_DIMENSIONS_AE[2] * IMG_DIMENSIONS_AE[3]
    local hidden_layer_size = 1024
    local latent_variable_size = 512

    local encoder = VAE.get_encoder(input_size, hidden_layer_size, latent_variable_size)
    local decoder = VAE.get_decoder(input_size, hidden_layer_size, latent_variable_size, VAE.continuous)

    -- Wire the graph: input -> encoder -> (mean, log_var) -> sampler -> decoder.
    local input = nn.Identity()()
    local mean, log_var = encoder(input):split(2)
    local z = nn.Sampler()({mean, log_var})

    local model, criterion_reconstruction
    if VAE.continuous then
        -- The continuous decoder ends in a ConcatTable with two heads, so its
        -- output node has to be split into two separate graph outputs.
        local reconstruction, reconstruction_var = decoder(z):split(2)
        model = nn.gModule({input}, {reconstruction, reconstruction_var, mean, log_var})
        criterion_reconstruction = nn.GaussianCriterion()
    else
        local reconstruction = decoder(z)
        model = nn.gModule({input}, {reconstruction, mean, log_var})
        criterion_reconstruction = nn.BCECriterion()
        criterion_reconstruction.sizeAverage = false
    end

    local criterion_latent = nn.KLDCriterion()

    local parameters, gradients = model:getParameters()

    return model, criterion_latent, criterion_reconstruction, parameters, gradients
end
|
||
|
||
--- Build the encoder: four strided convolution stages, one fully connected
-- layer, and two parallel linear heads that emit the latent mean and
-- log-variance.
--
-- Reads the globals IMG_DIMENSIONS_AE and GPU. When GPU is truthy the input
-- is copied from FloatTensor to CudaTensor at the front, each head copies its
-- output back to FloatTensor, and the whole module is converted with :cuda().
--
-- @param input_size            unused here (kept for interface compatibility)
-- @param hidden_layer_size     width of the fully connected layer
-- @param latent_variable_size  width of each output head
-- @return nn.Sequential whose output is a table of the two head outputs
function VAE.get_encoder(input_size, hidden_layer_size, latent_variable_size)
    local encoder = nn.Sequential()
    if GPU then
        encoder:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor', true, true))
    end

    -- Channel progression of the conv stack; every stage is a 5x5 convolution
    -- with stride 2 and padding (5-1)/2, followed by batch norm and LeakyReLU.
    local channels = {IMG_DIMENSIONS_AE[1], 8, 16, 32, 64}
    for i = 1, #channels - 1 do
        local n_in, n_out = channels[i], channels[i + 1]
        encoder:add(nn.SpatialConvolution(n_in, n_out, 5, 5, 2, 2, (5-1)/2, (5-1)/2))
        encoder:add(nn.SpatialBatchNormalization(n_out))
        encoder:add(nn.LeakyReLU(0.2, true))
    end

    -- Flattened size after four stride-2 stages.
    -- NOTE(review): this only yields an integer matching the conv output when
    -- both spatial dimensions are divisible by 16 -- confirm for the image
    -- sizes in use.
    local outSize = 64 * IMG_DIMENSIONS_AE[2]/2/2/2/2 * IMG_DIMENSIONS_AE[3]/2/2/2/2
    encoder:add(nn.Reshape(outSize))
    encoder:add(nn.Linear(outSize, hidden_layer_size))
    encoder:add(nn.BatchNormalization(hidden_layer_size))
    encoder:add(nn.LeakyReLU(0.2, true))

    -- Two parallel heads. This was previously an accidental global shared
    -- with get_decoder; it is local now.
    local mean_logvar = nn.ConcatTable()
    for _ = 1, 2 do
        local head = nn.Linear(hidden_layer_size, latent_variable_size)
        if GPU then
            -- Copy each head's output back to a FloatTensor individually.
            head = nn.Sequential()
                :add(head)
                :add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
        end
        mean_logvar:add(head)
    end
    encoder:add(mean_logvar)

    if GPU then
        encoder:cuda()
    end

    return encoder
end
|
||
--- Build the decoder that maps a latent vector back to image space.
--
-- Reads the globals IMG_DIMENSIONS_AE and GPU. When GPU is truthy the input
-- is copied FloatTensor -> CudaTensor at the front, outputs are copied back to
-- FloatTensor, and the module is converted with :cuda().
--
-- Output depends on `continuous`:
--   * truthy: a table of two linear heads, each of width input_size.
--   * falsy:  a sigmoid layer of width input_size/2/2, reshaped to
--             (IMG_DIMENSIONS_AE[1], IMG_DIMENSIONS_AE[2]/2, IMG_DIMENSIONS_AE[3]/2)
--             and upsampled by nn.SpatialUpSamplingNearest(2).
--
-- @param input_size            flattened image size
-- @param hidden_layer_size     width of the first fully connected layer
-- @param latent_variable_size  width of the latent input
-- @param continuous            selects the two-headed output described above
-- @return nn.Sequential decoder
function VAE.get_decoder(input_size, hidden_layer_size, latent_variable_size, continuous)
    local to_float = function()
        return nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true)
    end

    local decoder = nn.Sequential()
    if GPU then
        decoder:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor', true, true))
    end
    decoder:add(nn.Linear(latent_variable_size, hidden_layer_size))
    decoder:add(nn.BatchNormalization(hidden_layer_size))
    decoder:add(nn.LeakyReLU(0.2, true))

    if continuous then
        -- Previously an accidental global shared with get_encoder; local now.
        local mean_logvar = nn.ConcatTable()
        for _ = 1, 2 do
            local head = nn.Linear(hidden_layer_size, input_size)
            if GPU then
                -- nn.Copy operates on a single tensor, so it cannot follow the
                -- ConcatTable (whose output is a table). Copy each head's
                -- output back individually, as the encoder does.
                head = nn.Sequential():add(head):add(to_float())
            end
            mean_logvar:add(head)
        end
        decoder:add(mean_logvar)
    else
        decoder:add(nn.Linear(hidden_layer_size, input_size/2/2))
        decoder:add(nn.Sigmoid(true))
        decoder:add(nn.Reshape(IMG_DIMENSIONS_AE[1], IMG_DIMENSIONS_AE[2]/2, IMG_DIMENSIONS_AE[3]/2))
        decoder:add(nn.SpatialUpSamplingNearest(2))
        if GPU then
            decoder:add(to_float())
        end
    end

    if GPU then
        decoder:cuda()
    end

    return decoder
end
|
||
--- Run one optim.adam step of the VAE on a batch.
--
-- Relies on `optim` and `util` being available as globals (they are not
-- required in this file) and on the module flag VAE.continuous, which decides
-- how the model output is unpacked.
--
-- Side effects: prints the batch lower bound and shows the input and
-- reconstruction batches through util.displayBatch (windows 10 and 11).
--
-- @param inputs                   batch fed to model:forward
-- @param model                    network returned by VAE.createVAE
-- @param criterionLatent          criterion applied to (mean, log_var)
-- @param criterionReconstruction  criterion applied to (reconstruction, inputs)
-- @param parameters               flattened parameter tensor of the model
-- @param gradParameters           flattened gradient tensor of the model
-- @param optconfig                config table passed to optim.adam
-- @param optstate                 state table passed to optim.adam
-- @return the second value returned by optim.adam
--         (per the Torch optim docs this is a table holding the evaluated
--         loss -- TODO confirm against the installed optim version)
function VAE.train(inputs, model, criterionLatent, criterionReconstruction, parameters, gradParameters, optconfig, optstate)
    -- Validate once up front instead of on every closure evaluation.
    assert(inputs ~= nil)
    assert(model ~= nil)
    assert(criterionLatent ~= nil)
    assert(criterionReconstruction ~= nil)
    assert(parameters ~= nil)
    assert(gradParameters ~= nil)
    assert(optconfig ~= nil)
    assert(optstate ~= nil)

    local opfunc = function(x)
        if x ~= parameters then
            parameters:copy(x)
        end

        model:zeroGradParameters()

        local reconstruction, reconstruction_var, mean, log_var
        if VAE.continuous then
            reconstruction, reconstruction_var, mean, log_var = unpack(model:forward(inputs))
            -- The reconstruction criterion receives both decoder heads as a table.
            reconstruction = {reconstruction, reconstruction_var}
        else
            reconstruction, mean, log_var = unpack(model:forward(inputs))
        end

        local err = criterionReconstruction:forward(reconstruction, inputs)
        local df_dw = criterionReconstruction:backward(reconstruction, inputs)

        local KLDerr = criterionLatent:forward(mean, log_var)
        local dKLD_dmu, dKLD_dlog_var = unpack(criterionLatent:backward(mean, log_var))

        -- One gradient per graph output, in output order. This was an
        -- accidental global; it is local now.
        local error_grads
        if VAE.continuous then
            error_grads = {df_dw[1], df_dw[2], dKLD_dmu, dKLD_dlog_var}
        else
            error_grads = {df_dw, dKLD_dmu, dKLD_dlog_var}
        end

        model:backward(inputs, error_grads)

        local batchlowerbound = err + KLDerr

        print(string.format("[BATCH AE] lowerbound=%.8f", batchlowerbound))
        util.displayBatch(inputs, 10, "Training images for AE (input)")
        util.displayBatch(reconstruction, 11, "Training images for AE (output)")

        return batchlowerbound, gradParameters
    end

    -- The first return value (the updated parameters) is not needed here.
    local _, batchlowerbound = optim.adam(opfunc, parameters, optconfig, optstate)

    return batchlowerbound
end
|
||
return VAE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
-- Action: a pair of action ids, one arrow and one button.
local Action = {}
Action.__index = Action

--- Create a new Action.
-- @param arrowAction   value stored in the `arrow` field
-- @param buttonAction  value stored in the `button` field
-- @return table with Action as its metatable
function Action.new(arrowAction, buttonAction)
    return setmetatable({arrow = arrowAction, button = buttonAction}, Action)
end
|
||
return Action |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
local actions = {}

-- Numeric ids of the individual gamepad buttons.
local B, Y, SELECT, START = 0, 1, 2, 3
local UP, DOWN, LEFT, RIGHT = 4, 5, 6, 7
local A, X, L, R = 8, 9, 10, 11

actions.ACTION_BUTTON_B = B
actions.ACTION_BUTTON_Y = Y
actions.ACTION_BUTTON_SELECT = SELECT
actions.ACTION_BUTTON_START = START
actions.ACTION_BUTTON_UP = UP
actions.ACTION_BUTTON_DOWN = DOWN
actions.ACTION_BUTTON_LEFT = LEFT
actions.ACTION_BUTTON_RIGHT = RIGHT
actions.ACTION_BUTTON_A = A
actions.ACTION_BUTTON_X = X
actions.ACTION_BUTTON_L = L
actions.ACTION_BUTTON_R = R

-- Every button id, in id order.
actions.ACTIONS_ALL = {B, Y, SELECT, START, UP, DOWN, LEFT, RIGHT, A, X, L, R}

-- Ids listed for the network (a reduced set of B, LEFT, RIGHT, A was
-- previously kept here as a commented-out alternative).
actions.ACTIONS_NETWORK = {B, Y, UP, DOWN, LEFT, RIGHT, A, X}

-- Direction ids.
actions.ACTIONS_ARROWS = {UP, DOWN, LEFT, RIGHT}

-- Non-direction ids; SELECT, START, L and R are deliberately left out.
actions.ACTIONS_BUTTONS = {B, Y, A, X}

-- Maps a button id to the button name passed to input.do_button_action.
actions.ACTION_TO_BUTTON_NAME = {
    [B] = "gamepad-1-B",
    [Y] = "gamepad-1-Y",
    [SELECT] = "gamepad-1-select",
    [START] = "gamepad-1-start",
    [UP] = "gamepad-1-up",
    [DOWN] = "gamepad-1-down",
    [LEFT] = "gamepad-1-left",
    [RIGHT] = "gamepad-1-right",
    [A] = "gamepad-1-A",
    [X] = "gamepad-1-X",
    [L] = "gamepad-1-L",
    [R] = "gamepad-1-R",
}
|
||
--- Tell whether an action id is one of the entries of actions.ACTIONS_ARROWS.
-- @param actionIdx  action id to look up
-- @return true if the id is listed in ACTIONS_ARROWS, false otherwise
function actions.isArrowsActionIdx(actionIdx)
    for _, arrowId in ipairs(actions.ACTIONS_ARROWS) do
        if arrowId == actionIdx then
            return true
        end
    end
    return false
end
|
||
--- Tell whether an action id is one of the entries of actions.ACTIONS_BUTTONS.
-- @param actionIdx  action id to look up
-- @return true if the id is listed in ACTIONS_BUTTONS, false otherwise
function actions.isButtonsActionIdx(actionIdx)
    for _, buttonId in ipairs(actions.ACTIONS_BUTTONS) do
        if buttonId == actionIdx then
            return true
        end
    end
    return false
end
|
||
--- Create an Action from one random arrow id and one random button id.
-- Draws from math.random twice: first for the arrow, then for the button.
-- Expects `Action` to be reachable as a global (it is not required in this file).
-- @return result of Action.new(arrow, button)
function actions.createRandomAction()
    local arrows = actions.ACTIONS_ARROWS
    local buttons = actions.ACTIONS_BUTTONS
    local arrowId = arrows[math.random(#arrows)]
    local buttonId = buttons[math.random(#buttons)]
    return Action.new(arrowId, buttonId)
end
|
||
--- Send a release for every button listed in actions.ACTIONS_ALL.
-- Calls input.do_button_action once per button with state 0 and mode 3.
function actions.endAllActions()
    local RELEASED = 0 -- 1 = pressed, 0 = released
    local MODE = 3     -- 1 = autohold, 2 = framehold, others = press/release
    local names = actions.ACTION_TO_BUTTON_NAME
    for _, actionId in ipairs(actions.ACTIONS_ALL) do
        input.do_button_action(names[actionId], RELEASED, MODE)
    end
end
|
||
--[[ | ||
function endAction(action) | ||
local lcid = 1 | ||
local port, controller = input.lcid_to_pcid2(lcid) | ||
--local controller = 0 | ||
local value = 0 -- 0 = release, 1 = press | ||
input.set2(port, controller, action, value) | ||
end | ||
--]] | ||
|
||
--- Press the arrow and the button of an action.
-- Looks up action.arrow and action.button in actions.ACTION_TO_BUTTON_NAME
-- and calls input.do_button_action for each with state 1 and mode 3, arrow
-- first.
-- @param action  table with `arrow` and `button` ids; must not be nil, and
--                both ids must have an entry in ACTION_TO_BUTTON_NAME
function actions.startAction(action)
    assert(action ~= nil)

    local PRESSED = 1 -- 1 = pressed, 0 = released
    local MODE = 3    -- 1 = autohold, 2 = framehold, others = press/release

    local names = actions.ACTION_TO_BUTTON_NAME
    local arrowName = names[action.arrow]
    local buttonName = names[action.button]
    assert(arrowName ~= nil)
    assert(buttonName ~= nil)

    input.do_button_action(arrowName, PRESSED, MODE)
    input.do_button_action(buttonName, PRESSED, MODE)
end
|
||
--- Experimental helper: prints a marker and calls input.set2 with control 0
-- and value 0 on the port/controller returned for logical controller 1.
-- @param actions  unused (note that it shadows the module table inside this function)
function actions.setJoypad(actions)
    print("set joypad")
    local port, controller = input.lcid_to_pcid2(1)
    input.set2(port, controller, 0, 0)
end
|
||
--- Experimental helper: sends a 12-entry boolean table to input.joyset(1, ...).
-- All entries are false except the first, which is set to true when a single
-- math.random() draw is below 0.1. Before sending, it prints a marker together
-- with a table of "P1 <button>" keys; that table is only printed, never sent.
-- @param actions  unused (note that it shadows the module table inside this function)
function actions.setJoypad2(actions)
    -- The result is not used; the call is kept because the original made it.
    input.lcid_to_pcid2(1)

    -- Table that only appears in the debug print below.
    local p1Buttons = {}
    local buttonNames = {
        "B", "Y", "select", "start", "up", "down",
        "left", "right", "A", "X", "L", "R",
    }
    for _, buttonName in ipairs(buttonNames) do
        p1Buttons["P1 " .. buttonName] = true
    end

    -- The state table that is actually sent.
    local states = {}
    for slot = 1, 12 do
        states[slot] = false
    end
    if math.random() < 0.1 then
        states[1] = true
    end

    print("Sending to joyset...", p1Buttons)
    input.joyset(1, states)
end
|
||
return actions |
Oops, something went wrong.