forked from junyanz/CycleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathone_direction_test_model.lua
71 lines (57 loc) · 1.86 KB
/
one_direction_test_model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
-- Declares OneDirectionTestModel as a class deriving from 'BaseModel', using the
-- `class` package required at the top of this file. It is assigned without
-- `local`, i.e. as a global; presumably scripts that load this file refer to it
-- by name -- NOTE(review): confirm before making it local.
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')
--- Constructs the model object.
-- This constructor creates no networks or tensors; Initialize does that.
-- @param conf optional configuration table forwarded to BaseModel.__init;
--   nil is replaced by an empty table before forwarding.
function OneDirectionTestModel:__init(conf)
  -- Default first, so the base constructor actually receives the normalized
  -- value (previously the default was applied after the call and never used).
  conf = conf or {}
  BaseModel.__init(self, conf)
end
--- Returns the identifier string of this model class.
-- @return the constant string 'OneDirectionTestModel'
function OneDirectionTestModel:model_name()
  local name = 'OneDirectionTestModel'
  return name
end
--- Sets up the tensors and the network used by this test-only model.
-- Allocates a placeholder input tensor, loads the network named 'G' through
-- util.load_test_model, refreshes the flattened parameter views, and prints
-- the number of learnable parameters.
-- @param opt options table; batchSize, input_nc and fineSize are read here,
--   and the whole table is passed to util.load_test_model.
function OneDirectionTestModel:Initialize(opt)
  local batch, channels, side = opt.batchSize, opt.input_nc, opt.fineSize
  -- Placeholder buffer; Forward replaces self.real_A with the actual batch.
  self.real_A = torch.Tensor(batch, channels, side, side)

  self.netG_A = util.load_test_model('G', opt)
  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  local count = self.parametersG_A:size(1)
  print(string.format('G_A = %d', count))
  print('------------------------------------------------')
end
--- Runs the generator on one input batch and caches input and output.
-- The source batch is input.real_A, or input.real_B when
-- opt.which_direction == 'BtoA'. The caller's input table is not modified.
-- @param input table holding the real_A and/or real_B tensors
-- @param opt options table; which_direction and gpu are read here
-- Sets self.real_A (a copy of the source batch, converted with :cuda() when
-- opt.gpu > 0) and self.fake_B (a copy of netG_A's forward output).
function OneDirectionTestModel:Forward(input, opt)
  -- Select the source without writing back into the caller's table, and copy
  -- it exactly once (the batch used to be cloned twice for 'BtoA').
  local source = input.real_A
  if opt.which_direction == 'BtoA' then
    source = input.real_B
  end
  self.real_A = source:clone()
  if opt.gpu > 0 then
    self.real_A = self.real_A:cuda()
  end
  -- Copy the output so the cached result does not alias the network's output
  -- buffer (presumably reused by the network on its next forward call).
  self.fake_B = self.netG_A:forward(self.real_A):clone()
end
--- Re-fetches the generator's flattened parameter and gradient tensors.
-- The previous references are cleared first, then the two values returned by
-- self.netG_A:getParameters() are stored in parametersG_A and
-- gradparametersG_A respectively.
function OneDirectionTestModel:RefreshParameters()
  -- Drop the old references before requesting new ones.
  self.parametersG_A = nil
  self.gradparametersG_A = nil
  local params, gradParams = self.netG_A:getParameters()
  self.parametersG_A = params
  self.gradparametersG_A = gradParams
end
--- Ensures an image batch has three channels, for display.
-- Dimension 2 is treated as the channel axis, matching the
-- (batch, channels, size, size) layout allocated in Initialize.
-- @param im image tensor
-- @return im itself when dimension 2 is not 1; otherwise a tensor repeated
--   3 times along dimension 2 via torch.repeatTensor
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end
  -- Single channel: tile it into three identical channels.
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
--- Collects the current images for display.
-- @param opt options table; opt.display_winsize is read when size is absent
-- @param size optional display size
-- @return array of two {img=..., label=...} entries, real_A first and fake_B
--   second, each passed through MakeIm3 (1-channel -> 3-channel)
function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  -- NOTE(review): size is defaulted here but not used below; kept so the
  -- observable behavior (including the opt.display_winsize read) is unchanged.
  size = size or opt.display_winsize
  return {
    { img = MakeIm3(self.real_A), label = 'real_A' },
    { img = MakeIm3(self.fake_B), label = 'fake_B' },
  }
end