Skip to content

Commit ad5fc25

Browse files
committed
Adding CUDA tests for StepLSTM, StepGRU, and VariableLength
- Also added a helper flag to see if cuda is available - Edited the tests to use this flag
1 parent 218ebbf commit ad5fc25

File tree

2 files changed

+222
-9
lines changed

2 files changed

+222
-9
lines changed

init.lua

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
require 'torchx'
22
local _ = require 'moses'
33
require 'nn'
4-
pcall(require, 'cunn')
4+
local _cuda, _ = pcall(require, 'cunn')
55

66
-- create global rnn table:
77
rnn = {}
8+
rnn.cuda = _cuda
89
rnn.version = 2.7 -- better support for bidirection RNNs
910

1011
-- lua 5.2 compat

test/test.lua

Lines changed: 220 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,7 @@ function rnntest.SequencerCriterion()
15001500
local gradInputTable = sc:backward(split:forward(input), split:forward(target))
15011501
mytester:assertTensorEq(gradInputTensor, torch.cat(gradInputTable, 1):view(gradInputTensor:size()), 0, "SequencerCriterion backward type err ")
15021502

1503-
if pcall(function() require 'cunn' end) then
1503+
if rnn.cuda then
15041504
-- test cuda()
15051505
sc.gradInput = {}
15061506
sc:cuda()
@@ -2107,7 +2107,7 @@ function rnntest.MaskZeroCriterion()
21072107
mytester:assert(math.abs(err3 - err) < 0.0000001, "MaskZeroCriterion cast fwd err")
21082108
mytester:assertTensorEq(gradInput3, gradInput:float(), 0.0000001, "MaskZeroCriterion cast bwd err")
21092109

2110-
if pcall(function() require 'cunn' end) then
2110+
if rnn.cuda then
21112111
-- test cuda
21122112
mznll:cuda()
21132113
if v2 then
@@ -2828,6 +2828,49 @@ function rnntest.SeqLSTM_Lua_vs_C()
28282828
end
28292829
end
28302830

2831+
-- Verifies that nn.SeqLSTM produces matching forward/backward results on
-- CPU (float) and GPU for the same weights and inputs. Skipped when cunn
-- is unavailable.
function rnntest.SeqLSTM_cuda()
   if not rnn.cuda then
      return
   end

   -- run the CPU reference in float so it is directly comparable to CUDA
   local defaultType = torch.getdefaulttensortype()
   torch.setdefaulttensortype('torch.FloatTensor')

   local seqlen, batchsize = 3, 4
   local inputsize, outputsize = 2, 5
   local input = torch.randn(seqlen, batchsize, inputsize)

   -- two identically-initialized modules: one stays on CPU, one moves to GPU
   local hostlstm = nn.SeqLSTM(inputsize, outputsize)
   local devlstm = nn.SeqLSTM(inputsize, outputsize)
   devlstm.weight:copy(hostlstm.weight)
   devlstm.bias:copy(hostlstm.bias)
   devlstm:cuda()

   -- forward passes must agree (up to float precision)
   local hostOutput = hostlstm:forward(input)
   local devOutput = devlstm:forward(input:cuda())
   mytester:assertTensorEq(hostOutput, devOutput:float(), precision)

   hostlstm:zeroGradParameters()
   devlstm:zeroGradParameters()

   -- backward passes: gradInput and accumulated gradParams must agree
   local gradOutput = torch.randn(seqlen, batchsize, outputsize)
   local hostGradInput = hostlstm:backward(input, gradOutput)
   local devGradInput = devlstm:backward(input:cuda(), gradOutput:cuda())
   mytester:assertTensorEq(hostGradInput, devGradInput:float(), precision)

   local _, hostGradParams = hostlstm:parameters()
   local devParams, devGradParams = devlstm:parameters()
   for i=1,#devParams do
      mytester:assertTensorEq(hostGradParams[i], devGradParams[i]:float(), precision)
   end

   -- restore the default tensor type for subsequent tests
   torch.setdefaulttensortype(defaultType)
end
2873+
28312874
function rnntest.SeqLSTM_maskzero()
28322875
-- tests that it works with non-masked inputs regardless of maskzero's value.
28332876
-- Note that more maskzero = true tests with masked inputs are in SeqLSTM unit test.
@@ -2858,7 +2901,7 @@ function rnntest.SeqLSTM_maskzero()
28582901
mytester:assertTensorEq(gradParams, gradParams2, 0.000001)
28592902
if benchmark then
28602903
local T, N, D, H = 20, 20, 50, 50
2861-
if pcall(function() require 'cunn' end) then
2904+
if rnn.cuda then
28622905
T, N, D, H = 100, 128, 250, 250
28632906
end
28642907

@@ -3603,7 +3646,7 @@ function rnntest.SeqGRU_maskzero()
36033646
mytester:assertTensorEq(gradParams, gradParams2, 0.000001)
36043647
if benchmark then
36053648
local T, N, D, H = 20, 20, 50, 50
3606-
if pcall(function() require 'cunn' end) then
3649+
if rnn.cuda then
36073650
T, N, D, H = 100, 128, 250, 250
36083651
end
36093652

@@ -3682,6 +3725,49 @@ function rnntest.SeqGRU_Lua_vs_C()
36823725
end
36833726
end
36843727

3728+
-- Verifies that nn.SeqGRU produces matching forward/backward results on
-- CPU (float) and GPU for the same weights and inputs. Skipped when cunn
-- is unavailable.
function rnntest.SeqGRU_cuda()
   if not rnn.cuda then
      return
   end

   -- run the CPU reference in float so it is directly comparable to CUDA
   local defaultType = torch.getdefaulttensortype()
   torch.setdefaulttensortype('torch.FloatTensor')

   local seqlen, batchsize = 3, 4
   local inputsize, outputsize = 2, 5
   local input = torch.randn(seqlen, batchsize, inputsize)

   -- two identically-initialized modules: one stays on CPU, one moves to GPU
   local hostgru = nn.SeqGRU(inputsize, outputsize)
   local devgru = nn.SeqGRU(inputsize, outputsize)
   devgru.weight:copy(hostgru.weight)
   devgru.bias:copy(hostgru.bias)
   devgru:cuda()

   -- forward passes must agree (up to float precision)
   local hostOutput = hostgru:forward(input)
   local devOutput = devgru:forward(input:cuda())
   mytester:assertTensorEq(hostOutput, devOutput:float(), precision)

   hostgru:zeroGradParameters()
   devgru:zeroGradParameters()

   -- backward passes: gradInput and accumulated gradParams must agree
   local gradOutput = torch.randn(seqlen, batchsize, outputsize)
   local hostGradInput = hostgru:backward(input, gradOutput)
   local devGradInput = devgru:backward(input:cuda(), gradOutput:cuda())
   mytester:assertTensorEq(hostGradInput, devGradInput:float(), precision)

   local _, hostGradParams = hostgru:parameters()
   local devParams, devGradParams = devgru:parameters()
   for i=1,#devParams do
      mytester:assertTensorEq(hostGradParams[i], devGradParams[i]:float(), precision)
   end

   -- restore the default tensor type for subsequent tests
   torch.setdefaulttensortype(defaultType)
end
3770+
36853771
function checkgrad(opfunc, x, eps)
36863772
-- compute true gradient:
36873773
local _,dC = opfunc(x)
@@ -3865,6 +3951,51 @@ function rnntest.VariableLength_FromSamples()
38653951
end
38663952
end
38673953

3954+
-- CUDA version of VariableLength_FromSamples: packs a table of
-- variable-length CudaTensor samples into one padded (time x column)
-- tensor plus a byte mask, then verifies the packing on the host.
-- Skipped when cunn is unavailable.
function rnntest.VariableLength_FromSamples_cuda()
   if not rnn.cuda then
      return
   end

   torch.manualSeed(0)
   local nSamples = 10
   local maxLength = 20
   for run=1,10 do
      -- draw a random length in [1, maxLength] for each sample
      local lengths = torch.LongTensor(nSamples)
      lengths:random(maxLength)
      local samples = {}
      for i=1,nSamples do
         local t = torch.rand(lengths[i], 5)
         samples[i] = t:cuda()
      end
      local output = torch.CudaTensor()
      local mask = torch.CudaByteTensor()
      -- indexes[i] lists the sample ids packed into column i;
      -- mappedLengths[i] lists their corresponding lengths
      local indexes, mappedLengths = output.nn.VariableLength_FromSamples(samples, output, mask)

      -- move results back to the host for verification
      output = output:float()
      mask = mask:byte()
      for i, ids in ipairs(indexes) do
         -- m/t start as the full column and are progressively narrowed
         -- past each verified sample (plus its separator row) below
         local m = mask:select(2, i)
         local t = output:select(2, i)
         for j, sampleId in ipairs(ids) do
            local l = lengths[sampleId]
            -- check that the length was mapped correctly
            mytester:assert(l == mappedLengths[i][j])
            -- checks that the mask is 0 for valid entries
            mytester:assert(math.abs(m:narrow(1, 1, l):sum()) < 0.000001)
            -- checks that the valid entries are equal
            mytester:assertTensorEq(t:narrow(1, 1, l), samples[sampleId]:float())
            -- the row immediately after a sample must be masked out
            if l < m:size(1) then
               mytester:assert(m[l+1] == 1)
            end
            -- advance the views past this sample and its separator row,
            -- so the next iteration checks the next packed sample
            if l+1 < m:size(1) then
               m = m:narrow(1, l+2, m:size(1)-l-1)
               t = t:narrow(1, l+2, t:size(1)-l-1)
            end
         end
      end
   end
end
3998+
38683999
function rnntest.VariableLength_ToSamples()
38694000
local nSamples = 10
38704001
local maxLength = 20
@@ -3886,6 +4017,30 @@ function rnntest.VariableLength_ToSamples()
38864017
end
38874018
end
38884019

4020+
-- CUDA round-trip test: FromSamples followed by ToSamples must
-- reconstruct the original table of variable-length samples exactly.
-- Skipped when cunn is unavailable.
function rnntest.VariableLength_ToSamples_cuda()
   if not rnn.cuda then
      return
   end
   local nSamples = 10
   local maxLength = 20
   for run=1,10 do
      -- build a batch of CUDA samples with random lengths in [1, maxLength]
      local lengths = torch.LongTensor(nSamples)
      lengths:random(maxLength)
      local samples = {}
      for i=1,nSamples do
         samples[i] = torch.rand(lengths[i], 5):cuda()
      end

      -- pack, then immediately unpack
      local output = torch.CudaTensor()
      local mask = torch.CudaByteTensor()
      local indexes, mappedLengths = output.nn.VariableLength_FromSamples(samples, output, mask)
      local new_samples = output.nn.VariableLength_ToSamples(indexes, mappedLengths, output)

      -- the round trip must preserve both the count and the contents
      mytester:assert(#samples == #new_samples)
      for i=1,nSamples do
         mytester:assertTensorEq(samples[i]:float(), new_samples[i]:float())
      end
   end
end
4043+
38894044
function rnntest.VariableLength_ToFinal()
38904045
local nSamples = 10
38914046
local maxLength = 20
@@ -3910,6 +4065,30 @@ function rnntest.VariableLength_ToFinal()
39104065
end
39114066
end
39124067

4068+
-- CUDA version of VariableLength_ToFinal: packs variable-length samples,
-- extracts the final (last valid) time step of each, and checks it against
-- the last row of the corresponding input sample.
function rnntest.VariableLength_ToFinal_cuda()
   -- BUGFIX: skip when cunn is unavailable, consistent with the other
   -- *_cuda tests; without this guard the :cuda() calls below would error
   -- on CPU-only machines
   if not rnn.cuda then
      return
   end
   local nSamples = 10
   local maxLength = 20
   for run=1,10 do
      -- build a batch of CUDA samples with random lengths in [1, maxLength]
      local lengths = torch.LongTensor(nSamples)
      lengths:random(maxLength)
      local samples = {}
      for i=1,nSamples do
         local t = torch.rand(lengths[i], 5):cuda()
         samples[i] = t
      end
      local output = torch.CudaTensor()
      local mask = torch.CudaByteTensor()
      local indexes, mappedLengths = output.nn.VariableLength_FromSamples(samples, output, mask)

      -- final[i] should be the last valid time step of sample i
      local final = torch.CudaTensor()
      output.nn.VariableLength_ToFinal(indexes, mappedLengths, output, final)

      for i=1,nSamples do
         mytester:assertTensorEq(samples[i]:select(1, lengths[i]):float(), final:select(1, i):float())
      end
   end
end
4091+
39134092
function rnntest.VariableLength_FromFinal()
39144093
torch.manualSeed(2)
39154094
local nSamples = 10
@@ -3943,6 +4122,39 @@ function rnntest.VariableLength_FromFinal()
39434122
end
39444123
end
39454124

4125+
-- CUDA version of VariableLength_FromFinal: scatters per-sample final
-- states back into a packed tensor. Every time step before the last must
-- come back zero, and the last valid step must equal the original.
function rnntest.VariableLength_FromFinal_cuda()
   -- BUGFIX: skip when cunn is unavailable, consistent with the other
   -- *_cuda tests; without this guard the :cuda() calls below would error
   -- on CPU-only machines
   if not rnn.cuda then
      return
   end
   torch.manualSeed(2)
   local nSamples = 10
   local maxLength = 20
   for run=1,1 do
      -- build a batch of CUDA samples with random lengths in [1, maxLength]
      local lengths = torch.LongTensor(nSamples)
      lengths:random(maxLength)
      local samples = {}
      for i=1,nSamples do
         local t = torch.rand(lengths[i], 5):cuda()
         samples[i] = t
      end
      local output = torch.CudaTensor()
      local mask = torch.CudaByteTensor()
      local indexes, mappedLengths = output.nn.VariableLength_FromSamples(samples, output, mask)

      -- extract final states, then scatter them back into a packed tensor
      local final = torch.CudaTensor()
      output.nn.VariableLength_ToFinal(indexes, mappedLengths, output, final)

      local re_output = torch.CudaTensor()
      output.nn.VariableLength_FromFinal(indexes, mappedLengths, final, re_output)

      local new_samples = output.nn.VariableLength_ToSamples(indexes, mappedLengths, re_output)

      for i=1,nSamples do
         -- all steps before the final one must be zero-filled
         if lengths[i] > 1 then
            mytester:assert(new_samples[i]:narrow(1, 1, lengths[i]-1):abs():sum() < 0.000001)
         end
         -- the final step must round-trip exactly
         mytester:assertTensorEq(samples[i]:select(1, lengths[i]):float(), new_samples[i]:select(1, lengths[i]):float())
      end
   end
end
4157+
39464158
function rnntest.VariableLength_lstm()
39474159
-- test seqlen x batchsize x hiddensize
39484160
local maxLength = 8
@@ -4987,7 +5199,7 @@ function rnntest.VRClassReward()
49875199
mytester:assertTensorEq(gradInput[2], gradInput2, 0.000001, "VRClassReward backward baseline err")
49885200
mytester:assert(math.abs(gradInput[1]:sum()) < 0.000001, "VRClassReward backward class err")
49895201

4990-
if pcall(function() require 'cunn' end) then
5202+
if rnn.cuda then
49915203
local gradInput = {gradInput[1], gradInput[2]}
49925204
input[1], input[2] = input[1]:cuda(), input[2]:cuda()
49935205
target = target:cuda()
@@ -5859,7 +6071,7 @@ function rnntest.NCE_main()
58596071
end
58606072

58616073

5862-
if pcall(function() require 'cunn' end) then
6074+
if rnn.cuda then
58636075
-- test training with cuda
58646076

58656077
ncem:cuda()
@@ -6082,7 +6294,7 @@ function rnntest.NCE_batchnoise()
60826294
end
60836295

60846296

6085-
if pcall(function() require 'cunn' end) then
6297+
if rnn.cuda then
60866298
-- test training with cuda
60876299

60886300
ncem:cuda()
@@ -6168,7 +6380,7 @@ function rnntest.NCE_multicuda()
61686380
if not pcall(function() require 'torchx' end) then
61696381
return
61706382
end
6171-
if not pcall(function() require 'cunn' end) then
6383+
if not rnn.cuda then
61726384
return
61736385
end
61746386
if cutorch.getDeviceCount() < 2 then

0 commit comments

Comments
 (0)