-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathdonkey_folderC.lua
267 lines (215 loc) · 7.58 KB
/
donkey_folderC.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
--[[
This data loader is a modified version of the one from dcgan.torch
(see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
Copyright (c) 2015-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--
require 'image'
paths.dofile('dataset.lua')
-- Data-loading logic and details; this chunk is executed by each
-- data-loader thread.
------------------------------------------
-- COMMON CACHES and PATHS
-- Resolve the dataset directory: prefer opt.DATA_ROOT, otherwise fall back
-- to the DATA_ROOT environment variable; append the phase subfolder.
if opt.DATA_ROOT then
   opt.data = paths.concat(opt.DATA_ROOT, opt.phase)
else
   opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase)
end
if not paths.dirp(opt.data) then
   error('Did not find directory: ' .. opt.data)
end
-- Cache file for the training metadata (created on demand if absent).
local cache = "cache"
local prefix = (opt.data:gsub('/', '_'))
os.execute('mkdir -p cache')
local trainCache = paths.concat(cache, prefix .. '_trainCache.t7')
--------------------------------------------------------------------------------------------
local input_nc = opt.input_nc   -- number of input image channels
local output_nc = opt.output_nc -- number of output image channels
-- {channel groups, spatial size}: the first entry counts 3-channel images
-- stacked along the channel dimension.
local loadSize   = {input_nc / 3, opt.loadSize}
local sampleSize = {input_nc / 3, opt.fineSize}
-- Jointly preprocess an (A, B) image pair: resize both to loadSize, reverse
-- the channel order (original comment: brg <-> rgb), take the SAME random
-- crop to fineSize from each, and optionally apply the same horizontal flip.
-- Pixel values are expected (and asserted) to lie in [0, 1].
-- NOTE: unlike the upstream donkey_folder.lua, images are kept in [0, 1]
-- rather than rescaled to [-1, 1] (the :mul(2):add(-1) calls were removed;
-- the original non-English comment documented exactly this change).
-- Returns imA, imB, flip_flag (1 if the pair was flipped, else 0).
local preprocessAandBC = function(imA, imB)
   imA = image.scale(imA, loadSize[2], loadSize[2])
   imB = image.scale(imB, loadSize[2], loadSize[2])
   local perm = torch.LongTensor{3, 2, 1}
   imA = imA:index(1, perm) -- reverse channel order
   imB = imB:index(1, perm)
   assert(imA:max()<=1,"A: badly scaled inputs")
   assert(imA:min()>=0,"A: badly scaled inputs")
   assert(imB:max()<=1,"B: badly scaled inputs")
   assert(imB:min()>=0,"B: badly scaled inputs")
   local oW = sampleSize[2]
   local oH = sampleSize[2]
   local iH = imA:size(2)
   local iW = imA:size(3)
   -- Random crop offsets. FIX: these leaked as globals before (a race hazard
   -- between data-loader threads), and a crop in one dimension only would
   -- read an undefined/stale offset for the other. Declared local with a
   -- 0 default so a single-dimension crop is well defined.
   local h1, w1 = 0, 0
   if iH ~= oH then
      h1 = math.ceil(torch.uniform(1e-2, iH - oH))
   end
   if iW ~= oW then
      w1 = math.ceil(torch.uniform(1e-2, iW - oW))
   end
   if iH ~= oH or iW ~= oW then
      imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
      imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
   end
   local flip_flag = 0
   if opt.flip == 1 and torch.uniform() > 0.5 then
      imA = image.hflip(imA)
      imB = image.hflip(imB)
      flip_flag = 1
   end
   return imA, imB, flip_flag
end
-- Load an RGB image and convert it to a Lab-space AB tensor for the
-- colorization task: channel 1 is L mapped from [0, 100] to [-1, 1],
-- channels 2-3 are a/b divided by 110. A random fineSize crop and an
-- optional horizontal flip are applied before the color conversion.
local function loadImageChannel(path)
   local input = image.load(path, 3, 'float')
   input = image.scale(input, loadSize[2], loadSize[2])
   local oW = sampleSize[2]
   local oH = sampleSize[2]
   local iH = input:size(2)
   local iW = input:size(3)
   -- FIX: crop offsets were accidental globals (shared across loader
   -- threads) and undefined when only one dimension needed cropping.
   local h1, w1 = 0, 0
   if iH ~= oH then
      h1 = math.ceil(torch.uniform(1e-2, iH - oH))
   end
   if iW ~= oW then
      w1 = math.ceil(torch.uniform(1e-2, iW - oW))
   end
   if iH ~= oH or iW ~= oW then
      input = image.crop(input, w1, h1, w1 + oW, h1 + oH)
   end
   if opt.flip == 1 and torch.uniform() > 0.5 then
      input = image.hflip(input)
   end
   local input_lab = image.rgb2lab(input)
   -- L: [0, 100] -> [-1, 1]; ab: divided by 110 to land in roughly [-1, 1]
   local imA = input_lab[{{1}, {}, {} }]:div(50.0) - 1.0
   local imB = input_lab[{{2,3},{},{}}]:div(110.0)
   local imAB = torch.cat(imA, imB, 1)
   assert(imAB:max()<=1,"A: badly scaled inputs")
   assert(imAB:min()>=-1,"A: badly scaled inputs")
   return imAB
end
-- Load an aligned AB image (A and B stored side by side in one file) and
-- split it into its left (A) and right (B) halves.
local function loadImage(path)
   local img = image.load(path, 3, 'float')
   local h, w = img:size(2), img:size(3)
   local left  = image.crop(img, 0, 0, w/2, h)
   local right = image.crop(img, w/2, 0, w, h)
   return left, right
end
-- Build an inpainting pair from a single image: B is the full image with
-- channels reversed and rescaled to [-1, 1]; A is a copy of B whose central
-- (oH/2 x oW/2) square is blanked out (set to 1.0). Both may be flipped
-- together. Returns A and B concatenated along the channel dimension.
local function loadImageInpaint(path)
   local imB = image.load(path, 3, 'float')
   imB = image.scale(imB, loadSize[2], loadSize[2])
   local perm = torch.LongTensor{3, 2, 1}
   imB = imB:index(1, perm)   -- reverse channel order
   imB = imB:mul(2):add(-1)   -- rescale [0, 1] -> [-1, 1]
   assert(imB:max()<=1,"A: badly scaled inputs")
   assert(imB:min()>=-1,"A: badly scaled inputs")
   local oW = sampleSize[2]
   local oH = sampleSize[2]
   local iH = imB:size(2)
   local iW = imB:size(3)
   -- FIX: crop offsets were accidental globals (shared across loader
   -- threads) and undefined when only one dimension needed cropping.
   local h1, w1 = 0, 0
   if iH ~= oH then
      h1 = math.ceil(torch.uniform(1e-2, iH - oH))
   end
   if iW ~= oW then
      w1 = math.ceil(torch.uniform(1e-2, iW - oW))
   end
   if iH ~= oH or iW ~= oW then
      imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
   end
   local imA = imB:clone()
   imA[{{},{1 + oH/4, oH/2 + oH/4},{1 + oW/4, oW/2 + oW/4}}] = 1.0
   if opt.flip == 1 and torch.uniform() > 0.5 then
      imA = image.hflip(imA)
      imB = image.hflip(imB)
   end
   -- FIX: was an accidental global; keep it local to this loader thread.
   local imAB = torch.cat(imA, imB, 1)
   return imAB
end
-- channel-wise mean and std. Calculate or load them from disk later in the script.
-- NOTE(review): mean/std are declared here but never assigned anywhere in
-- this chunk — presumably populated by code outside this view; verify
-- before relying on them.
local mean,std
--------------------------------------------------------------------------------
-- Hooks that are used for each image that is loaded
-- function to load the image, jitter it appropriately (random crops etc.)
-- Per-sample hook invoked by the dataset loader for each training path.
-- Dispatches on opt.preprocess ('regular' | 'colorization' | 'inpaint')
-- and returns the concatenated AB tensor plus, for the 'regular' path,
-- a flag saying whether the pair was horizontally flipped (nil otherwise).
local trainHook = function(self, path)
   collectgarbage()
   local flip_flag
   -- FIX: imAB was an accidental global, shared (and potentially stale)
   -- across loader threads; an elseif chain also makes the modes mutually
   -- exclusive instead of three independent if-blocks.
   local imAB
   if opt.preprocess == 'regular' then
      local imA, imB = loadImage(path)
      imA, imB, flip_flag = preprocessAandBC(imA, imB)
      imAB = torch.cat(imA, imB, 1)
   elseif opt.preprocess == 'colorization' then
      imAB = loadImageChannel(path)
   elseif opt.preprocess == 'inpaint' then
      imAB = loadImageInpaint(path)
   else
      -- fail fast: previously an unknown mode returned nil (or a stale
      -- global) and crashed obscurely downstream
      error('unknown preprocess mode: ' .. tostring(opt.preprocess))
   end
   return imAB, flip_flag
end
--------------------------------------
-- trainLoader
-- Construct the (global) trainLoader used by the training threads.
print('trainCache', trainCache)
-- NOTE(review): the metadata cache path is intentionally disabled below;
-- metadata is rebuilt from opt.data on every run.
--if paths.filep(trainCache) then
--   print('Loading train metadata from cache')
--   trainLoader = torch.load(trainCache)
--   trainLoader.sampleHookTrain = trainHook
--   trainLoader.loadSize = {input_nc, opt.loadSize, opt.loadSize}
--   trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]}
--   trainLoader.serial_batches = opt.serial_batches
--   trainLoader.split = 100
--else
print('Creating train metadata')
-- print(opt.data)
print('serial batch:, ', opt.serial_batches)
-- dataLoader is provided by dataset.lua (dofile'd above); split=100 puts
-- every image in the training split.
trainLoader = dataLoader{
paths = {opt.data},
loadSize = {input_nc, loadSize[2], loadSize[2]},
sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]},
split = 100,
serial_batches = opt.serial_batches,
verbose = true
}
-- print('finish')
--torch.save(trainCache, trainLoader)
--print('saved metadata cache at', trainCache)
-- Install the per-sample preprocessing hook defined above.
trainLoader.sampleHookTrain = trainHook
--end
collectgarbage()
-- do some sanity checks on trainLoader: every image class index must lie
-- within [1, #classes].
do
local class = trainLoader.imageClass
local nClasses = #trainLoader.classes
assert(class:max() <= nClasses, "class logic has error")
assert(class:min() >= 1, "class logic has error")
end