diff --git a/models/cycle_gan_model.lua b/models/cycle_gan_model.lua index 66c6644..7397b03 100644 --- a/models/cycle_gan_model.lua +++ b/models/cycle_gan_model.lua @@ -2,7 +2,6 @@ local class = require 'class' require 'models.base_model' require 'models.architectures' require 'util.image_pool' -local optnet = require 'optnet' util = paths.dofile('../util/util.lua') CycleGANModel = class('CycleGANModel', 'BaseModel') @@ -51,9 +50,12 @@ function CycleGANModel:Initialize(opt) netG_B = util.load_test_model('G_B', opt) --setup optnet to save a little bit of memory - local sample_input = torch.randn(1, opt.input_nc, 2, 2) - optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true}) - optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true}) + if opt.use_optnet == 1 then + local sample_input = torch.randn(1, opt.input_nc, 2, 2) + local optnet = require 'optnet' + optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true}) + optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true}) + end else netG_A = util.load_model('G_A', opt) netG_B = util.load_model('G_B', opt) diff --git a/models/one_direction_test_model.lua b/models/one_direction_test_model.lua index 13c593d..b06c186 100644 --- a/models/one_direction_test_model.lua +++ b/models/one_direction_test_model.lua @@ -2,7 +2,6 @@ local class = require 'class' require 'models.base_model' require 'models.architectures' require 'util.image_pool' -local optnet = require 'optnet' util = paths.dofile('../util/util.lua') OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel') @@ -25,8 +24,11 @@ function OneDirectionTestModel:Initialize(opt) self.netG_A = util.load_test_model('G', opt) -- setup optnet to save a bit of memory - local sample_input = torch.randn(1, opt.input_nc, 2, 2) - optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true}) + if opt.use_optnet == 1 then + local optnet = require 'optnet' + local sample_input = torch.randn(1, opt.input_nc, 2, 2) + optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true}) + end self:RefreshParameters() diff --git a/options.lua b/options.lua index 41bcd71..a4c9315 100644 --- a/options.lua +++ b/options.lua @@ -49,6 +49,7 @@ local opt_train = { pool_size = 50, -- the size of image buffer that stores previously generated images resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy identity = 0, -- use identity mapping. Setting opt.identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set opt.identity = 0.1 + use_optnet = 0, -- use optnet to save GPU memory during test } -- options for test