Skip to content

Commit

Permalink
fix input_nc bug
Browse files — browse the repository at this point in the history
  • Loading branch information
junyanz committed Apr 7, 2017
1 parent 31b7884 commit bf48e72
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
4 changes: 2 additions & 2 deletions models/architectures.lua
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ function defineG_resnet_6blocks(input_nc, output_nc, ngf)
local f = 7
local p = (f - 1) / 2
local data = -nn.Identity()
local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(3, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
Expand All @@ -258,7 +258,7 @@ function defineG_resnet_9blocks(input_nc, output_nc, ngf)
local f = 7
local p = (f - 1) / 2
local data = -nn.Identity()
local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(3, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
Expand Down
28 changes: 17 additions & 11 deletions util/InstanceNormalization.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
require 'nn'

_ = [[
An implementation for https://arxiv.org/abs/1607.08022
--[[
Implements instance normalization as described in the paper
Instance Normalization: The Missing Ingredient for Fast Stylization
Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
https://arxiv.org/abs/1607.08022
This implementation is based on
https://github.com/DmitryUlyanov/texture_nets
]]

local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module')
Expand All @@ -19,24 +25,24 @@ function InstanceNormalization:__init(nOutput, eps, momentum, affine)
else
self.affine = true
end

self.nOutput = nOutput
self.prev_batch_size = -1

if self.affine then
if self.affine then
self.weight = torch.Tensor(nOutput):uniform()
self.bias = torch.Tensor(nOutput):zero()
self.gradWeight = torch.Tensor(nOutput)
self.gradBias = torch.Tensor(nOutput)
end
end
end

function InstanceNormalization:updateOutput(input)
self.output = self.output or input.new()
assert(input:size(2) == self.nOutput)

local batch_size = input:size(1)

if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type()) then
self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine)
self.bn:type(self:type())
Expand All @@ -58,7 +64,7 @@ function InstanceNormalization:updateOutput(input)

local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
self.output = self.bn:forward(input_1obj):viewAs(input)

return self.output
end

Expand All @@ -67,9 +73,9 @@ function InstanceNormalization:updateGradInput(input, gradOutput)

assert(self.bn)

local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))

if self.affine then
self.bn.gradWeight:zero()
self.bn.gradBias:zero()
Expand All @@ -87,7 +93,7 @@ end
function InstanceNormalization:clearState()
self.output = self.output.new()
self.gradInput = self.gradInput.new()

if self.bn then
self.bn:clearState()
end
Expand Down

0 comments on commit bf48e72

Please sign in to comment.