Add .t7 file reader
apaszke committed Nov 24, 2016
1 parent 8b492bb commit bcfa2d6
Showing 16 changed files with 696 additions and 19 deletions.
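The headline change is a pure-Python reader for Torch7 .t7 serialized files, exposed as torch.utils.serialization.load_lua, together with legacy-nn fixes uncovered by round-tripping modules through it. A minimal usage sketch (the checkpoint path is hypothetical):

from torch.utils.serialization import load_lua

# Deserialize a Torch7 file; tensors, storages, and legacy nn modules
# come back as their Python counterparts (see TestLuaReader below).
obj = load_lua('checkpoint.t7')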
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
+test/data/legacy_modules.t7
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
Binary file removed test/data/legacy_modules.t7
Binary file not shown.
6 changes: 3 additions & 3 deletions test/test_legacy_nn.py
@@ -149,16 +149,16 @@ def _do_test(self, test_case, module, input):
OldModuleTest(nn.Sum,
(1,),
input_size=(2, 4, 5),
-                  reference_fn=lambda i,_: i.sum(1)),
+                  reference_fn=lambda i,_: i.sum(1).squeeze(1)),
OldModuleTest(nn.Sum,
(1, True),
input_size=(2, 4, 5),
-                  reference_fn=lambda i,_: i.sum(1).div(i.size(1)),
+                  reference_fn=lambda i,_: i.sum(1).div(i.size(1)).squeeze(1),
desc='sizeAverage'),
OldModuleTest(nn.Mean,
(1,),
input_size=(2, 4, 5),
-                  reference_fn=lambda i,_: torch.mean(i, 1)),
+                  reference_fn=lambda i,_: torch.mean(i, 1).squeeze(1)),
OldModuleTest(lambda: nn.Sequential().add(nn.GradientReversal()).add(nn.GradientReversal()),
input_size=(4, 3, 2, 2),
fullname='GradientReversal'),
118 changes: 118 additions & 0 deletions test/test_utils.py
@@ -1,3 +1,4 @@
+from __future__ import print_function
import sys
import os
import math
@@ -9,10 +10,12 @@
import traceback
import torch
import torch.cuda
+import warnings
from torch.autograd import Variable
from torch.utils.trainer import Trainer
from torch.utils.trainer.plugins import *
from torch.utils.trainer.plugins.plugin import Plugin
+from torch.utils.serialization import load_lua

HAS_CUDA = torch.cuda.is_available()

@@ -245,5 +248,120 @@ def test_gpu(self):
lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))


class TestLuaReader(TestCase):

@staticmethod
def _module_test(name, test):
def do_test(self):
module = test['module']
input = test['input']
grad_output = test['grad_output']
if hasattr(self, '_transform_' + name):
input = getattr(self, '_transform_' + name)(input)
output = module.forward(input)
module.zeroGradParameters()
grad_input = module.backward(input, grad_output)
self.assertEqual(output, test['output'])
self.assertEqual(grad_input, test['grad_input'])
if module.parameters() is not None:
params, d_params = module.parameters()
self.assertEqual(params, test['params'])
self.assertEqual(d_params, test['d_params'])
else:
self.assertFalse('params' in test and test['params'])
self.assertFalse('params' in test and test['d_params'])
return do_test

@staticmethod
def _criterion_test(name, test):
def do_test(self):
module = test['module']
input = test['input']
if name == 'L1Cost':
target = None
else:
target = test['target']
if hasattr(self, '_transform_' + name):
input, target = getattr(self, '_transform_' + name)(input, target)

output = module.forward(input, target)
grad_input = module.backward(input, target)
self.assertEqual(output, test['loss'])
self.assertEqual(grad_input, test['grad_input'])
return do_test

@classmethod
def _download_data(cls, test_file_path):
if os.path.exists(test_file_path):
return
print('Downloading test file for TestLuaReader.')
DATA_URL = 'https://s3.amazonaws.com/pytorch/legacy_modules.t7'
urllib = cls._get_urllib('request')
data = urllib.urlopen(DATA_URL).read()
with open(test_file_path, 'wb') as f:
f.write(data)

@staticmethod
def _get_urllib(submodule):
if sys.version_info < (3,):
import urllib2
return urllib2
else:
import urllib.error
import urllib.request
return getattr(urllib, submodule)

@classmethod
def init(cls):
data_dir = os.path.join(os.path.dirname(__file__), 'data')
test_file_path = os.path.join(data_dir, 'legacy_modules.t7')
urllib = cls._get_urllib('error')
try:
cls._download_data(test_file_path)
except urllib.URLError as e:
warnings.warn(("Couldn't download the test file for TestLuaReader! "
"Tests will be incomplete!"), RuntimeWarning)
return

tests = load_lua(test_file_path)
for name, test in tests['modules'].items():
test_name = 'test_' + name.replace('nn.', '')
setattr(cls, test_name, cls._module_test(name, test))
for name, test in tests['criterions'].items():
test_name = 'test_' + name.replace('nn.', '')
setattr(cls, test_name, cls._criterion_test(name, test))

def _transform_Index(self, input):
return [input[0], input[1].sub(1)]

def _transform_LookupTable(self, input):
return input.sub(1)

def _transform_MultiLabelMarginCriterion(self, input, target):
return input, target.sub(1)

def _transform_ClassNLLCriterion(self, input, target):
return input, target.sub(1)

def _transform_SpatialClassNLLCriterion(self, input, target):
return input, target.sub(1)

def _transform_ClassSimplexCriterion(self, input, target):
return input, target.sub(1)

def _transform_CrossEntropyCriterion(self, input, target):
return input, target.sub(1)

def _transform_ParallelCriterion(self, input, target):
return input, [target[0].sub(1), target[1]]

def _transform_MultiCriterion(self, input, target):
return input, target.sub(1)

def _transform_MultiMarginCriterion(self, input, target):
return input, target.sub(1)


TestLuaReader.init()
if __name__ == '__main__':
unittest.main()
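For reference, TestLuaReader.init() above registers one closure per serialized module before unittest collects test methods; the same pattern in isolation, with hypothetical names:

import unittest

class DynamicTests(unittest.TestCase):
    pass

def _make_test(payload):
    def do_test(self):
        self.assertEqual(payload, payload)  # stand-in for the real module check
    return do_test

# One test_* method per entry, mirroring TestLuaReader.init().
for name in ('Linear', 'Tanh'):
    setattr(DynamicTests, 'test_' + name, _make_test(name))

if __name__ == '__main__':
    unittest.main()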
1 change: 1 addition & 0 deletions torch/legacy/nn/BCECriterion.py
@@ -1,6 +1,7 @@
import torch
from .Criterion import Criterion

+# TODO: use THNN
class BCECriterion(Criterion):
eps = 1e-12

Expand Down
2 changes: 2 additions & 0 deletions torch/legacy/nn/ClassNLLCriterion.py
@@ -14,6 +14,7 @@ def __init__(self, weights=None, sizeAverage=True):
self.total_weight_tensor = torch.ones(1)

def updateOutput(self, input, target):
+        target = target.long()
self._backend.ClassNLLCriterion_updateOutput(
self._backend.library_state,
input,
@@ -29,6 +30,7 @@ def updateOutput(self, input, target):

def updateGradInput(self, input, target):
self.gradInput.resize_as_(input).zero_()
+        target = target.long()

self._backend.ClassNLLCriterion_updateGradInput(
self._backend.library_state,
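The added casts are presumably needed because tensors deserialized from Lua default to DoubleTensor, while THNN's ClassNLLCriterion indexes with integer class labels; a trivial sketch of the conversion:

import torch

target = torch.ones(4)  # class labels read back from a .t7 file arrive as doubles
target = target.long()  # THNN expects a LongTensor of class indices
print(target.type())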
6 changes: 3 additions & 3 deletions torch/legacy/nn/Concat.py
@@ -5,7 +5,7 @@ class Concat(Container):

def __init__(self, dimension):
super(Concat, self).__init__()
-        self.size = torch.Size()
+        self.outputSize = torch.Size()
self.dimension = dimension

def updateOutput(self, input):
@@ -17,8 +17,8 @@ def updateOutput(self, input):
size = list(currentOutput.size())
else:
size[self.dimension] += currentOutput.size(self.dimension)
-        self.size = torch.Size(size)
-        self.output.resize_(self.size)
+        self.outputSize = torch.Size(size)
+        self.output.resize_(self.outputSize)

offset = 0
for i, module in enumerate(self.modules):
2 changes: 1 addition & 1 deletion torch/legacy/nn/CriterionTable.py
@@ -12,7 +12,7 @@ def updateOutput(self, input):
self.output = self.criterion.updateOutput(*input)
return self.output

-    def updateGradInput(self, input):
+    def updateGradInput(self, input, grad_output):
self.criterion.updateGradInput(*input)
return self.gradInput

12 changes: 6 additions & 6 deletions torch/legacy/nn/DepthConcat.py
@@ -18,13 +18,13 @@ class DepthConcat(Concat):

def windowNarrow(self, output, currentOutput, offset):
outputWindow = output.narrow(self.dimension, offset, currentOutput.size(self.dimension))
-        for dim in range(len(self.size)):
+        for dim in range(len(self.outputSize)):
currentSize = currentOutput.size(dim)
-            if dim != self.dimension and self.size[dim] != currentSize:
+            if dim != self.dimension and self.outputSize[dim] != currentSize:
# 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side)
# 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side)
# 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad)
-                start = int(math.floor(((self.size[dim] - currentSize) / 2)))
+                start = int(math.floor(((self.outputSize[dim] - currentSize) / 2)))
outputWindow = outputWindow.narrow(dim, start, currentSize)
return outputWindow

@@ -37,13 +37,13 @@ def updateOutput(self, input):
size = list(currentOutput.size())
else:
size[self.dimension] += currentOutput.size(self.dimension)
-            for dim in range(len(self.size)):
+            for dim in range(len(self.outputSize)):
if dim != self.dimension:
# take the maximum size (shouldn't change anything for batch dim)
size[dim] = max(size[dim], currentOutput.size(dim))

-        self.size = torch.Size(size)
-        self.output.resize_(self.size).zero_() # zero for padding
+        self.outputSize = torch.Size(size)
+        self.output.resize_(self.outputSize).zero_() # zero for padding

offset = 0
for i, module in enumerate(self.modules):
2 changes: 1 addition & 1 deletion torch/legacy/nn/NarrowTable.py
@@ -26,7 +26,7 @@ def updateGradInput(self, input, gradOutput):

for i in range(len(input)):
if i < self.offset or i >= self.offset + self.length:
-                self.gradInput[i] = recursiveResizeAs(self.gradInput[i] or torch.Tensor(), input[i])
+                self.gradInput[i] = recursiveResizeAs(self.gradInput[i] or torch.Tensor(), input[i])[0]
recursiveFill(self.gradInput[i], 0)

return self.gradInput
2 changes: 1 addition & 1 deletion torch/legacy/nn/SpatialReflectionPadding.py
@@ -39,6 +39,6 @@ def updateGradInput(self, input, gradOutput):

def __repr__(self):
s = super(SpatialReflectionPadding, self).__repr__()
-        s += '({}, {}, {}, {})'.foramat(self.pad_l, self.pad_r, self.pad_t, self.pad_b)
+        s += '({}, {}, {}, {})'.format(self.pad_l, self.pad_r, self.pad_t, self.pad_b)
return s

4 changes: 2 additions & 2 deletions torch/legacy/nn/SpatialReplicationPadding.py
@@ -39,7 +39,7 @@ def updateGradInput(self, input, gradOutput):
return self.gradInput

def __repr__(self):
-        s = super(SpatialReflectionPadding, self).__repr__()
-        s += '({}, {}, {}, {})'.foramat(self.pad_l, self.pad_r, self.pad_t, self.pad_b)
+        s = super(SpatialReplicationPadding, self).__repr__()
+        s += '({}, {}, {}, {})'.format(self.pad_l, self.pad_r, self.pad_t, self.pad_b)
return s

2 changes: 2 additions & 0 deletions torch/legacy/nn/Sum.py
@@ -23,6 +23,8 @@ def updateOutput(self, input):
torch.sum(self.output, input, dimension)
if self.sizeAverage:
self.output.div_(input.size(dimension))
+        if self.output.dim() > 1:
+            self.output.set_(self.output.select(dimension, 0))

return self.output

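The two added lines make the legacy Sum drop the reduced dimension, which torch.sum kept by default at the time; a shape sketch of the equivalent behaviour, assuming current torch with an explicit keepdim flag:

import torch

x = torch.randn(2, 4, 5)
kept = torch.sum(x, 1, keepdim=True)  # (2, 1, 5): what torch.sum(x, 1) produced back then
out = kept.select(1, 0)               # (2, 5): what nn.Sum now returns
print(kept.size(), out.size())

The squeeze(1) calls added to the reference functions in test_legacy_nn.py above mirror this.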
2 changes: 0 additions & 2 deletions torch/legacy/nn/VolumetricAveragePooling.py
@@ -36,8 +36,6 @@ def updateGradInput(self, input, gradOutput):
def __repr__(self):
s = super(VolumetricAveragePooling, self).__repr__()
s += '({}x{}x{}, {}, {}, {}'.format(self.kT, self.kW, self.kH, self.dT, self.dW, self.dH)
-        if self.padT != 0 or self.padW != 0 or self.padH != 0:
-            s += ', {}, {}, {}'.format(self.padT, self.padW, self.padH)
s += ')'
return s

2 changes: 2 additions & 0 deletions torch/utils/serialization/__init__.py
@@ -0,0 +1,2 @@

+from .read_lua_file import load_lua, T7Reader
0 comments on commit bcfa2d6