forked from junyanz/CycleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcontent_loss.lua
110 lines (99 loc) · 3.34 KB
/
content_loss.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
require 'torch'
require 'nn'
-- Module table: the loss-network builders and the content-loss update
-- functions below are attached to it, and it is returned at the end of the file.
local content = {}
--- Load a Caffe model and build a feature extractor truncated at a named layer.
-- The returned network bilinearly resizes its input to 224x224, applies
-- nn.VGG_postprocess, then runs clones of the Caffe layers in order, up to and
-- including the first layer whose name equals `content_layer`. If no layer
-- matches, every layer is kept and a warning is printed.
-- @tparam string prototxt path to the Caffe .prototxt file
-- @tparam string caffemodel path to the Caffe .caffemodel weights
-- @tparam string content_layer name of the last layer to keep
-- @treturn nn.Sequential the truncated feature network
local function buildTruncatedCaffeNet(prototxt, caffemodel, content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  -- Keep the full model local so it can be collected once the needed layers
  -- have been cloned (previously this was an accidental global).
  local cnn = loadcaffe.load(prototxt, caffemodel, 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  local found = false
  for i = 1, #cnn do
    -- Clone so the truncated net does not share module objects with `cnn`.
    local layer = cnn:get(i):clone()
    contentFunc:add(layer)
    if layer.name == content_layer then
      print("Setting up content layer: ", layer.name)
      found = true
      break
    end
  end
  if not found then
    print("Warning: content layer not found, using the whole network: ", content_layer)
  end
  -- Drop the reference to the full model and reclaim its memory.
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end

--- Build a VGG feature network truncated at `content_layer`.
-- Loads ../models/vgg.prototxt and ../models/vgg.caffemodel with the cudnn backend.
-- @tparam string content_layer name of the last VGG layer to keep
-- @treturn nn.Sequential
function content.defineVGG(content_layer)
  return buildTruncatedCaffeNet('../models/vgg.prototxt', '../models/vgg.caffemodel', content_layer)
end

--- Build an AlexNet feature network truncated at `content_layer`.
-- Loads ../models/alexnet.prototxt and ../models/alexnet.caffemodel with the cudnn backend.
-- @tparam string content_layer name of the last AlexNet layer to keep
-- @treturn nn.Sequential
function content.defineAlexNet(content_layer)
  return buildTruncatedCaffeNet('../models/alexnet.prototxt', '../models/alexnet.caffemodel', content_layer)
end
--- Build the HED network loaded from Caffe, preceded by input preprocessing.
-- The returned network bilinearly resizes its input to 500x500, applies
-- nn.VGG_postprocess, then runs the full Caffe model loaded from
-- ../models/hed.prototxt / ../models/hed.caffemodel (cudnn backend).
-- @treturn nn.Sequential
function content.defineHED()
  local hed = nn.Sequential()
  require 'loadcaffe'
  -- require 'caffegraph'
  require 'util/VGG_preprocess'
  -- Local, so the loaded model is not leaked into the global table.
  local cnn = loadcaffe.load('../models/hed.prototxt', '../models/hed.caffemodel', 'cudnn')
  hed:add(nn.SpatialUpSamplingBilinear({oheight=500, owidth=500}))
  hed:add(nn.VGG_postprocess())
  hed:add(cnn)
  return hed
end
--- Build the complete (untruncated) VGG network, preceded by input preprocessing.
-- The returned network bilinearly resizes its input to 224x224, applies
-- nn.VGG_postprocess, then runs the full Caffe model loaded from
-- ../models/vgg.prototxt / ../models/vgg.caffemodel (cudnn backend).
-- @treturn nn.Sequential
function content.defineVGGClf()
  local clf = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  -- Local, so the loaded model is not leaked into the global table.
  local cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn')
  clf:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  clf:add(nn.VGG_postprocess())
  clf:add(cnn)
  return clf
end
--- Select the feature network required by a content-loss type.
-- @tparam string content_loss 'vgg', 'pixel', or 'none'
-- @tparam string layer_name VGG layer to truncate at (used only for 'vgg')
-- @return an nn.Sequential from content.defineVGG for 'vgg'; nil for
--   'pixel' and 'none'; nil for anything else, after printing a message
function content.defineContent(content_loss, layer_name)
  if content_loss == 'vgg' then
    return content.defineVGG(layer_name)
  end
  local needs_no_net = (content_loss == 'pixel') or (content_loss == 'none')
  if not needs_no_net then
    print("unsupported content loss")
  end
  return nil
end
--- Compute the weighted content loss and its gradient w.r.t. `fake_target`.
-- @param criterionContent criterion providing forward/backward (used for 'pixel' and 'vgg')
-- @param real_source tensor that `fake_target` is compared against
-- @param fake_target generated tensor; the returned gradient is w.r.t. it
-- @param contentFunc feature network (used only for 'vgg')
-- @tparam string loss_type 'none', 'pixel', or 'vgg'
-- @tparam number weight scale applied to both the loss and the gradient
-- @treturn number the weighted loss (0 for 'none')
-- @return the weighted gradient w.r.t. `fake_target` (zeros for 'none')
-- @raise error("unsupported content loss") for any other loss_type
function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
  if loss_type == 'none' then
    -- Allocate zeros of the same tensor type as fake_target (e.g. a CudaTensor),
    -- so the result can be combined with other gradients without a type mismatch.
    local df_d_content = fake_target.new():resizeAs(fake_target):zero()
    return 0.0, df_d_content
  elseif loss_type == 'pixel' then
    local errCont = criterionContent:forward(fake_target, real_source) * weight
    local df_do_content = criterionContent:backward(fake_target, real_source) * weight
    return errCont, df_do_content
  elseif loss_type == 'vgg' then
    -- Forward the real image FIRST and clone its features. updateGradInput
    -- relies on the state left by the most recent forward pass (e.g. ReLU
    -- outputs, pooling indices), so that last pass must be on fake_target.
    local f_real = contentFunc:forward(real_source):clone()
    -- No clone needed: nothing overwrites contentFunc.output before f_fake's last use.
    local f_fake = contentFunc:forward(fake_target)
    local errCont = criterionContent:forward(f_fake, f_real) * weight
    local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight
    -- The weight is already folded into df_do_tmp.
    local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)
    return errCont, df_do_content
  else
    error("unsupported content loss")
  end
end
-- Export the module table.
return content