forked from tgyg-jegli/tf_texture_net

InstanceNormalization.lua
--[[
Copied from https://github.com/DmitryUlyanov/texture_nets
Copyright Texture Nets Dmitry Ulyanov
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
require 'nn'

--[[
An implementation of instance normalization (https://arxiv.org/abs/1607.08022):
each (sample, channel) feature map is normalized by its own spatial mean and
variance, implemented here on top of nn.SpatialBatchNormalization.
]]--
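--[[
Usage sketch (an illustration, not part of the original file; assumes only
the standard Torch 'nn' modules):

   require 'nn'
   require 'InstanceNormalization'

   local model = nn.Sequential()
   model:add(nn.SpatialConvolution(3, 32, 3, 3, 1, 1, 1, 1))
   model:add(nn.InstanceNormalization(32))
   model:add(nn.ReLU(true))

   local x = torch.randn(4, 3, 64, 64)  -- batch of 4 RGB 64x64 images
   local y = model:forward(x)           -- shape 4x32x64x64, same as the conv output
]]--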
local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module')
function InstanceNormalization:__init(nOutput, eps, momentum, affine)
   parent.__init(self)
   self.running_mean = torch.zeros(nOutput)
   self.running_var = torch.ones(nOutput)

   self.eps = eps or 1e-5
   self.momentum = momentum or 0.0
   if affine ~= nil then
      assert(type(affine) == 'boolean', 'affine has to be true/false')
      self.affine = affine
   else
      self.affine = true
   end

   self.nOutput = nOutput
   self.prev_batch_size = -1

   if self.affine then
      -- Learnable per-channel scale and shift, as in batch normalization
      self.weight = torch.Tensor(nOutput):uniform()
      self.bias = torch.Tensor(nOutput):zero()
      self.gradWeight = torch.Tensor(nOutput)
      self.gradBias = torch.Tensor(nOutput)
   end
end
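--[[
Constructor sketch (illustrative): the defaults follow the code above,
i.e. eps = 1e-5, momentum = 0.0, affine = true.

   local in1 = nn.InstanceNormalization(64)                   -- learnable scale/shift
   local in2 = nn.InstanceNormalization(64, 1e-4, 0.0, false) -- plain normalization
]]--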
function InstanceNormalization:updateOutput(input)
   self.output = self.output or input.new()
   assert(input:size(2) == self.nOutput)

   local batch_size = input:size(1)

   -- (Re)build the inner batch-norm module whenever the batch size or the
   -- tensor type changes; it normalizes batch_size*nOutput "channels".
   if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type()) then
      self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine)
      self.bn:type(self:type())

      self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size))
      self.bn.running_var:copy(self.running_var:repeatTensor(batch_size))

      self.prev_batch_size = input:size(1)
   end

   -- Get statistics: average the tiled running stats back to per-channel values
   self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1))
   self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1))

   -- Set params for BN: tile the per-channel affine parameters across the batch
   if self.affine then
      self.bn.weight:copy(self.weight:repeatTensor(batch_size))
      self.bn.bias:copy(self.bias:repeatTensor(batch_size))
   end

   -- View the (N, C, H, W) input as a single (1, N*C, H, W) "image" so that
   -- batch normalization computes separate statistics per (sample, channel) pair
   local input_1obj = input:view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
   self.output = self.bn:forward(input_1obj):viewAs(input)

   return self.output
end
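--[[
Sanity-check sketch (illustrative): after a forward pass with affine = false,
every (sample, channel) slice of the output should have roughly zero mean and
unit variance.

   local inorm = nn.InstanceNormalization(8, nil, nil, false)
   local x = torch.randn(2, 8, 16, 16)
   local y = inorm:forward(x)
   print(y[{1, 1}]:mean(), y[{1, 1}]:var())  -- ~0 and ~1
]]--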
function InstanceNormalization:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput or gradOutput.new()

   assert(self.bn, 'updateGradInput called before updateOutput')

   local input_1obj = input:view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
   local gradOutput_1obj = gradOutput:view(1,input:size(1)*input:size(2),input:size(3),input:size(4))

   if self.affine then
      self.bn.gradWeight:zero()
      self.bn.gradBias:zero()
   end

   self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input)

   if self.affine then
      -- Fold the tiled per-(sample, channel) gradients back into per-channel ones
      self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1))
      self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1))
   end

   return self.gradInput
end
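--[[
Gradient-check sketch (illustrative; assumes nn's Jacobian test helper,
which ships with torch/nn):

   local module = nn.InstanceNormalization(4)
   local input = torch.randn(2, 4, 5, 5)
   local err = nn.Jacobian.testJacobian(module, input)
   print('max Jacobian error:', err)  -- should be on the order of 1e-6
]]--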
function InstanceNormalization:clearState()
   self.output = self.output.new()
   self.gradInput = self.gradInput.new()
   -- self.bn only exists after the first forward pass
   if self.bn then
      self.bn:clearState()
   end
end
-- Instance normalization uses the statistics of the current input at both
-- train and test time, so mode switching is deliberately a no-op: the inner
-- batch-normalization module must always stay in training mode.
function InstanceNormalization:evaluate()
end

function InstanceNormalization:training()
end