Skip to content

Commit

Permalink
add use_optnet flag (default=0)
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed Jul 13, 2017
1 parent 39a7f5e commit 4ea574b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
10 changes: 6 additions & 4 deletions models/cycle_gan_model.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
local optnet = require 'optnet'

util = paths.dofile('../util/util.lua')
CycleGANModel = class('CycleGANModel', 'BaseModel')
Expand Down Expand Up @@ -51,9 +50,12 @@ function CycleGANModel:Initialize(opt)
netG_B = util.load_test_model('G_B', opt)

--setup optnet to save a little bit of memory
local sample_input = torch.randn(1, opt.input_nc, 2, 2)
optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true})
optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true})
if opt.use_optnet == 1 then
local sample_input = torch.randn(1, opt.input_nc, 2, 2)
local optnet = require 'optnet'
optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true})
optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true})
end
else
netG_A = util.load_model('G_A', opt)
netG_B = util.load_model('G_B', opt)
Expand Down
8 changes: 5 additions & 3 deletions models/one_direction_test_model.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
local optnet = require 'optnet'

util = paths.dofile('../util/util.lua')
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')
Expand All @@ -25,8 +24,11 @@ function OneDirectionTestModel:Initialize(opt)
self.netG_A = util.load_test_model('G', opt)

-- setup optnet to save a bit of memory
local sample_input = torch.randn(1, opt.input_nc, 2, 2)
optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true})
if opt.use_optnet == 1 then
local optnet = require 'optnet'
local sample_input = torch.randn(1, opt.input_nc, 2, 2)
optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true})
end

self:RefreshParameters()

Expand Down
1 change: 1 addition & 0 deletions options.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ local opt_train = {
pool_size = 50, -- the size of image buffer that stores previously generated images
resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy
identity = 0, -- use identity mapping. Setting opt.identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set opt.identity = 0.1
use_optnet = 0, -- use optnet to save GPU memory during test
}

-- options for test
Expand Down

0 comments on commit 4ea574b

Please sign in to comment.