forked from junyanz/CycleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharchitectures.lua
384 lines (338 loc) · 17.8 KB
/
architectures.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
require 'nngraph'
----------------------------------------------------------------------------
--- Initialise the parameters of one module in place.
-- Used as the callback passed to `net:apply(...)` by defineG and defineD.
--   * type name contains "Convolution":   weight:normal(0.0, 0.02), bias zeroed
--   * type name contains "Normalization": weight:normal(1.0, 0.02), bias zeroed
--   * anything else is left untouched
-- @param m the module to initialise
local function weights_init(m)
  local name = torch.type(m)
  if name:find('Convolution') then
    m.weight:normal(0.0, 0.02)
    -- A convolution may have been created without a bias term; only reset
    -- the bias when one exists instead of indexing nil.
    if m.bias then m.bias:fill(0) end
  elseif name:find('Normalization') then
    -- Normalization layers may carry no learnable parameters at all.
    if m.weight then m.weight:normal(1.0, 0.02) end
    if m.bias then m.bias:fill(0) end
  end
end
-- Constructor of the normalization layer used by the network builders in this
-- file. It stays nil until set_normalization() has been called successfully.
normalization = nil

--- Select which normalization layer the network builders will use.
-- Stores the chosen layer constructor in the global `normalization`.
-- @param norm 'instance' (loads util.InstanceNormalization) or 'batch'
-- Raises an error for any other value, so a typo is reported here rather than
-- later as "attempt to call a nil value" inside a builder.
function set_normalization(norm)
  if norm == 'instance' then
    require 'util.InstanceNormalization'
    print('use InstanceNormalization')
    normalization = nn.InstanceNormalization
  elseif norm == 'batch' then
    print('use SpatialBatchNormalization')
    normalization = nn.SpatialBatchNormalization
  else
    error("unsupported normalization: " .. tostring(norm))
  end
end
--- Build a generator network by name and initialise its weights.
-- @param input_nc          number of input channels
-- @param output_nc         number of output channels
-- @param ngf               base number of generator filters
-- @param which_model_netG  "encoder_decoder", "unet128", "unet256",
--                          "resnet_6blocks" or "resnet_9blocks"
-- @param nz                accepted but not used by this function
-- @param arch              accepted but not used by this function
-- @return the network, after `apply(weights_init)`
-- Raises "unsupported netG model" for any other model name.
function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch)
  -- Closures defer the lookup of each builder until it is actually selected.
  local builders = {
    encoder_decoder = function() return defineG_encoder_decoder(input_nc, output_nc, ngf) end,
    unet128         = function() return defineG_unet128(input_nc, output_nc, ngf) end,
    unet256         = function() return defineG_unet256(input_nc, output_nc, ngf) end,
    resnet_6blocks  = function() return defineG_resnet_6blocks(input_nc, output_nc, ngf) end,
    resnet_9blocks  = function() return defineG_resnet_9blocks(input_nc, output_nc, ngf) end,
  }

  local build = builders[which_model_netG]
  if build == nil then
    error("unsupported netG model")
  end

  local generator = build()
  generator:apply(weights_init)
  return generator
end
--- Build a discriminator network by name and initialise its weights.
-- @param input_nc          number of input channels
-- @param ndf               base number of discriminator filters
-- @param which_model_netD  "basic", "imageGAN" or "n_layers"
-- @param n_layers_D        layer count, only forwarded for "n_layers"
-- @param use_sigmoid       when truthy the builders append a Sigmoid
-- @return the network, after `apply(weights_init)`
-- Raises "unsupported netD model" for any other model name.
function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid)
  -- Closures defer the lookup of each builder until it is actually selected.
  local builders = {
    basic    = function() return defineD_basic(input_nc, ndf, use_sigmoid) end,
    imageGAN = function() return defineD_imageGAN(input_nc, ndf, use_sigmoid) end,
    n_layers = function() return defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid) end,
  }

  local build = builders[which_model_netD]
  if build == nil then
    error("unsupported netD model")
  end

  local discriminator = build()
  discriminator:apply(weights_init)
  return discriminator
end
--- Build an encoder-decoder generator (no skip connections) as an nngraph gModule.
-- Encoder: eight 4x4 / stride-2 / pad-1 convolutions, each after the first
-- preceded by LeakyReLU(0.2); all but the first and last are followed by the
-- global `normalization` layer.
-- Decoder: eight 4x4 / stride-2 / pad-1 full convolutions, each preceded by
-- ReLU; the first seven are normalized, the first three also get Dropout(0.5).
-- A final Tanh produces the output.
-- Each stride-2 layer halves (encoder) or doubles (decoder) the spatial size,
-- so the layer count presumably targets 256x256 inputs -- confirm with callers.
-- @param input_nc  number of input channels
-- @param output_nc number of output channels
-- @param ngf       base number of filters
-- @return nn.gModule with one input and one output
function defineG_encoder_decoder(input_nc, output_nc, ngf)
  local function conv_down(n_in, n_out)
    return nn.SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
  end
  local function conv_up(n_in, n_out)
    return nn.SpatialFullConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
  end

  -- Encoder stages after the first: {in multiple, out multiple, normalize?}.
  -- The innermost stage is left un-normalized.
  local encoder_spec = {
    { 1, 2, true }, { 2, 4, true }, { 4, 8, true }, { 8, 8, true },
    { 8, 8, true }, { 8, 8, true }, { 8, 8, false },
  }
  -- Decoder stages before the last: {in multiple, out multiple, dropout?}.
  local decoder_spec = {
    { 8, 8, true }, { 8, 8, true }, { 8, 8, true }, { 8, 8, false },
    { 8, 4, false }, { 4, 2, false }, { 2, 1, false },
  }

  -- The first convolution is the graph's input node.
  local input = - conv_down(input_nc, ngf)
  local node = input

  for _, stage in ipairs(encoder_spec) do
    node = node - nn.LeakyReLU(0.2, true) - conv_down(ngf * stage[1], ngf * stage[2])
    if stage[3] then
      node = node - normalization(ngf * stage[2])
    end
  end

  for _, stage in ipairs(decoder_spec) do
    node = node - nn.ReLU(true) - conv_up(ngf * stage[1], ngf * stage[2])
    node = node - normalization(ngf * stage[2])
    if stage[3] then
      node = node - nn.Dropout(0.5)
    end
  end

  -- Last up-convolution maps to the output channels, then Tanh.
  node = node - nn.ReLU(true) - conv_up(ngf, output_nc)
  local output = node - nn.Tanh()

  return nn.gModule({ input }, { output })
end
--- Build a U-Net generator with seven down/up stages as an nngraph gModule.
-- Encoder: seven 4x4 / stride-2 / pad-1 convolutions; stages 2..6 are
-- normalized with the global `normalization` layer, the innermost one is not.
-- Decoder: each stage is ReLU -> full convolution -> normalization
-- (-> Dropout(0.5) for the first three) and its output is concatenated along
-- dimension 2 with the mirrored encoder activation (skip connection).
-- A final full convolution and Tanh produce the output.
-- The layer count presumably targets 128x128 inputs -- confirm with callers.
-- @param input_nc  number of input channels
-- @param output_nc number of output channels
-- @param ngf       base number of filters
-- @return nn.gModule with one input and one output
function defineG_unet128(input_nc, output_nc, ngf)
  -- Output width (as a multiple of ngf) of every encoder stage.
  local enc_mult = { 1, 2, 4, 8, 8, 8, 8 }
  -- Decoder stages that are followed by a skip concatenation.
  local dec_spec = {
    { mult = 8, dropout = true },
    { mult = 8, dropout = true },
    { mult = 8, dropout = true },
    { mult = 4, dropout = false },
    { mult = 2, dropout = false },
    { mult = 1, dropout = false },
  }

  -- Encoder: keep every stage so the decoder can reach back for skips.
  local enc = {}
  enc[1] = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
  for i = 2, #enc_mult do
    local n_in, n_out = ngf * enc_mult[i - 1], ngf * enc_mult[i]
    local node = enc[i - 1] - nn.LeakyReLU(0.2, true)
    node = node - nn.SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    if i < #enc_mult then
      node = node - normalization(n_out)
    end
    enc[i] = node
  end

  -- Decoder: upsample, then join with the encoder stage of matching depth.
  local node = enc[#enc]
  local n_in = ngf * 8
  for k, stage in ipairs(dec_spec) do
    local n_out = ngf * stage.mult
    local up = node - nn.ReLU(true)
    up = up - nn.SpatialFullConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    up = up - normalization(n_out)
    if stage.dropout then
      up = up - nn.Dropout(0.5)
    end
    node = { up, enc[#enc - k] } - nn.JoinTable(2)
    n_in = n_out * 2 -- the concatenation doubles the channel count
  end

  local last = node - nn.ReLU(true)
  last = last - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
  local output = last - nn.Tanh()

  return nn.gModule({ enc[1] }, { output })
end
--- Build a U-Net generator with eight down/up stages as an nngraph gModule.
-- Encoder: eight 4x4 / stride-2 / pad-1 convolutions; stages 2..7 are
-- normalized with the global `normalization` layer, the innermost one is not.
-- Decoder: each stage is ReLU -> full convolution -> normalization
-- (-> Dropout(0.5) for the first three) and its output is concatenated along
-- dimension 2 with the mirrored encoder activation (skip connection).
-- A final full convolution and Tanh produce the output.
-- The layer count presumably targets 256x256 inputs -- confirm with callers.
-- @param input_nc  number of input channels
-- @param output_nc number of output channels
-- @param ngf       base number of filters
-- @return nn.gModule with one input and one output
function defineG_unet256(input_nc, output_nc, ngf)
  -- Output width (as a multiple of ngf) of every encoder stage.
  local enc_mult = { 1, 2, 4, 8, 8, 8, 8, 8 }
  -- Decoder stages that are followed by a skip concatenation.
  local dec_spec = {
    { mult = 8, dropout = true },
    { mult = 8, dropout = true },
    { mult = 8, dropout = true },
    { mult = 8, dropout = false },
    { mult = 4, dropout = false },
    { mult = 2, dropout = false },
    { mult = 1, dropout = false },
  }

  -- Encoder: keep every stage so the decoder can reach back for skips.
  local enc = {}
  enc[1] = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
  for i = 2, #enc_mult do
    local n_in, n_out = ngf * enc_mult[i - 1], ngf * enc_mult[i]
    local node = enc[i - 1] - nn.LeakyReLU(0.2, true)
    node = node - nn.SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    if i < #enc_mult then
      node = node - normalization(n_out)
    end
    enc[i] = node
  end

  -- Decoder: upsample, then join with the encoder stage of matching depth.
  local node = enc[#enc]
  local n_in = ngf * 8
  for k, stage in ipairs(dec_spec) do
    local n_out = ngf * stage.mult
    local up = node - nn.ReLU(true)
    up = up - nn.SpatialFullConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
    up = up - normalization(n_out)
    if stage.dropout then
      up = up - nn.Dropout(0.5)
    end
    node = { up, enc[#enc - k] } - nn.JoinTable(2)
    n_in = n_out * 2 -- the concatenation doubles the channel count
  end

  local last = node - nn.ReLU(true)
  last = last - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
  local output = last - nn.Tanh()

  return nn.gModule({ enc[1] }, { output })
end
--------------------------------------------------------------------------------
-- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
--------------------------------------------------------------------------------
--- Build the two-convolution body of a residual block.
-- Layout: [pad] conv3x3 norm ReLU [pad] conv3x3 norm, all with `dim` channels.
-- @param dim          number of input and output channels
-- @param padding_type 'reflect' or 'replicate' insert an explicit 1-pixel
--                     padding module before each convolution; 'zero' uses the
--                     convolution's own padding of 1; any other value adds no
--                     padding at all
-- @return nn.Sequential
local function build_conv_block(dim, padding_type)
  -- Fresh explicit padding module for this padding type, or nil if none.
  local function explicit_padding()
    if padding_type == 'reflect' then
      return nn.SpatialReflectionPadding(1, 1, 1, 1)
    end
    if padding_type == 'replicate' then
      return nn.SpatialReplicationPadding(1, 1, 1, 1)
    end
    return nil
  end

  -- Only 'zero' padding is delegated to the convolution itself.
  local conv_pad = (padding_type == 'zero') and 1 or 0

  local block = nn.Sequential()
  for stage = 1, 2 do
    local pad = explicit_padding()
    if pad then
      block:add(pad)
    end
    block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, conv_pad, conv_pad))
    block:add(normalization(dim))
    -- No activation after the second convolution.
    if stage == 1 then
      block:add(nn.ReLU(true))
    end
  end
  return block
end
--- Build one residual block: output = conv_block(x) + x.
-- A ConcatTable feeds the input to both the convolutional branch and an
-- Identity shortcut; CAddTable sums the two results.
-- @param dim          number of channels
-- @param padding_type forwarded to build_conv_block
-- @return nn.Sequential
local function build_res_block(dim, padding_type)
  local residual_branch = build_conv_block(dim, padding_type)

  local branches = nn.ConcatTable()
  branches:add(residual_branch)
  branches:add(nn.Identity())

  local block = nn.Sequential()
  block:add(branches)
  block:add(nn.CAddTable())
  return block
end
--- Build a ResNet-style generator with six residual blocks as an nngraph gModule.
-- Layout: reflection-padded 7x7 convolution, two 3x3 / stride-2 convolutions
-- (ngf -> 2*ngf -> 4*ngf), six residual blocks at 4*ngf channels, two
-- 3x3 / stride-2 full convolutions back down to ngf channels, then a
-- reflection-padded 7x7 convolution and Tanh.
-- Uses the global `normalization` layer after every convolution except the last.
-- @param input_nc  number of input channels
-- @param output_nc number of output channels
-- @param ngf       base number of filters
-- @return nn.gModule with one input and one output
function defineG_resnet_6blocks(input_nc, output_nc, ngf)
  -- Kept local: the original assigned this without `local`, leaking a global.
  local padding_type = 'reflect'
  local n_blocks = 6
  local ks = 3
  local f = 7
  local p = (f - 1) / 2 -- padding that keeps the 7x7 convolutions size-preserving

  local data = -nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf*2 / 2, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)

  -- Chain the residual blocks one after another.
  local d1 = e3
  for _ = 1, n_blocks do
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end

  -- The trailing 1,1 is the output adjustment of the full convolution.
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()

  return nn.gModule({data}, {d4})
end
--- Build a ResNet-style generator with nine residual blocks as an nngraph gModule.
-- Layout: reflection-padded 7x7 convolution, two 3x3 / stride-2 convolutions
-- (ngf -> 2*ngf -> 4*ngf), nine residual blocks at 4*ngf channels, two
-- 3x3 / stride-2 full convolutions back down to ngf channels, then a
-- reflection-padded 7x7 convolution and Tanh.
-- Uses the global `normalization` layer after every convolution except the last.
-- @param input_nc  number of input channels
-- @param output_nc number of output channels
-- @param ngf       base number of filters
-- @return nn.gModule with one input and one output
function defineG_resnet_9blocks(input_nc, output_nc, ngf)
  -- Kept local: the original assigned this without `local`, leaking a global.
  local padding_type = 'reflect'
  local n_blocks = 9
  local ks = 3
  local f = 7
  local p = (f - 1) / 2 -- padding that keeps the 7x7 convolutions size-preserving

  local data = -nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)

  -- Chain the residual blocks one after another.
  local d1 = e3
  for _ = 1, n_blocks do
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end

  -- The trailing 1,1 is the output adjustment of the full convolution.
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()

  return nn.gModule({data}, {d4})
end
--- Build the whole-image discriminator as an nn.Sequential.
-- Seven 4x4 / stride-2 / pad-1 convolutions; each halves an even spatial size,
-- so a 256x256 input would end as a 2x2 map (256 / 2^7).
-- The five middle convolutions are followed by nn.SpatialBatchNormalization
-- (this builder does not use the global `normalization`) and LeakyReLU(0.2).
-- @param input_nc    number of input channels
-- @param ndf         base number of discriminator filters
-- @param use_sigmoid when truthy, append a Sigmoid to the output
-- @return nn.Sequential
function defineD_imageGAN(input_nc, ndf, use_sigmoid)
  local function down(n_in, n_out)
    return nn.SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1)
  end

  local net = nn.Sequential()

  -- First layer: no normalization.
  net:add(down(input_nc, ndf))
  net:add(nn.LeakyReLU(0.2, true))

  -- Middle layers: channel width grows to 8 * ndf and then stays there.
  local widths = { ndf * 2, ndf * 4, ndf * 8, ndf * 8, ndf * 8 }
  local prev = ndf
  for _, width in ipairs(widths) do
    net:add(down(prev, width))
    net:add(nn.SpatialBatchNormalization(width)):add(nn.LeakyReLU(0.2, true))
    prev = width
  end

  -- Final layer collapses to a single channel.
  net:add(down(ndf * 8, 1))

  if use_sigmoid then
    net:add(nn.Sigmoid())
  end
  return net
end
--- Build the default discriminator: defineD_n_layers with three layers.
-- @param input_nc    number of input channels
-- @param ndf         base number of discriminator filters
-- @param use_sigmoid forwarded to defineD_n_layers
-- @return whatever defineD_n_layers returns
function defineD_basic(input_nc, ndf, use_sigmoid)
  -- Kept local: the original assigned this without `local`, leaking a global.
  local n_layers = 3
  return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid)
end
-- rf=1
--- Build the per-pixel discriminator as an nn.Sequential.
-- Every convolution is 1x1 with stride 1 and no padding, so the spatial size
-- of the input is preserved and each output value depends on one pixel only.
-- Channels go input_nc -> ndf -> 2*ndf -> 1; the middle layer is followed by
-- the global `normalization` layer.
-- @param input_nc    number of input channels
-- @param ndf         base number of discriminator filters
-- @param use_sigmoid when truthy, append a Sigmoid to the output
-- @return nn.Sequential
function defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  local function pointwise(n_in, n_out)
    return nn.SpatialConvolution(n_in, n_out, 1, 1, 1, 1, 0, 0)
  end

  local net = nn.Sequential()

  net:add(pointwise(input_nc, ndf))
  net:add(nn.LeakyReLU(0.2, true))

  net:add(pointwise(ndf, ndf * 2))
  net:add(normalization(ndf * 2))
  net:add(nn.LeakyReLU(0.2, true))

  -- Single-channel score map.
  net:add(pointwise(ndf * 2, 1))

  if use_sigmoid then
    net:add(nn.Sigmoid())
  end
  return net
end
-- if n=0, then use pixelGAN (rf=1)
-- else rf is 16 if n=1
-- 34 if n=2
-- 70 if n=3
-- 142 if n=4
-- 286 if n=5
-- 574 if n=6
--- Build a patch discriminator with a configurable number of layers.
-- Layout for n_layers >= 1:
--   conv(stride 2) + LeakyReLU,
--   (n_layers - 1) x [conv(stride 2) + normalization + Dropout + LeakyReLU],
--   conv(stride 1) + normalization + LeakyReLU,
--   conv(stride 1) down to a single channel, optional Sigmoid.
-- The channel multiplier doubles per layer and is capped at 8 * ndf.
-- @param input_nc      number of input channels
-- @param ndf           base number of discriminator filters
-- @param n_layers      number of layers; 0 delegates to defineD_pixelGAN
-- @param use_sigmoid   when truthy, append a Sigmoid to the output
-- @param kw            convolution kernel width/height (default 4)
-- @param dropout_ratio probability passed to nn.Dropout (default 0.0)
-- @return nn.Sequential
function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio)
  if dropout_ratio == nil then
    dropout_ratio = 0.0
  end
  if kw == nil then
    kw = 4
  end

  if n_layers == 0 then
    return defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  end

  -- Kept local: the original assigned this without `local`, leaking a global.
  local padw = math.ceil((kw - 1) / 2)

  local netD = nn.Sequential()
  netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw))
  netD:add(nn.LeakyReLU(0.2, true))

  local nf_mult = 1
  local nf_mult_prev = 1
  for n = 1, n_layers - 1 do
    nf_mult_prev = nf_mult
    nf_mult = math.min(2^n, 8)
    netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw, padw))
    netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio))
    netD:add(nn.LeakyReLU(0.2, true))
  end

  -- From here on the convolutions use stride 1.
  nf_mult_prev = nf_mult
  nf_mult = math.min(2^n_layers, 8)
  netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw))
  netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))

  -- Single-channel score map.
  netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw, padw))

  if use_sigmoid then
    netD:add(nn.Sigmoid())
  end
  return netD
end