Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions LookupTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function LookupTable:updateOutput(input)
end

function LookupTable:updateGradInput(input, gradOutput)
-- the input can be of any type (as in the forward it's
-- the input can be of any type (as in the forward it's
-- converted anyway to LongTensor) thus, need to allocate
-- new memory each time the user changes the input type
if torch.type(self.gradInput) ~= torch.type(input) then
Expand Down Expand Up @@ -148,10 +148,10 @@ function LookupTable:type(type, tensorCache)

if type == 'torch.CudaTensor' then
-- CUDA uses _sorted and _indices temporary tensors
self._sorted = self.weight.new()
self._indices = self.weight.new()
self._count = self.weight.new()
self._input = self.weight.new()
self._sorted = torch.CudaLongTensor.new()
self._indices = torch.CudaLongTensor.new()
self._count = torch.CudaLongTensor.new()
self._input = torch.CudaLongTensor.new()
else
-- self._count and self._input should only be converted if using Cuda
self._count = torch.IntTensor()
Expand Down