Fix Embedding and CosineEmbeddingLoss on non-float CUDA (pytorch#965)
fmassa authored and soumith committed Mar 9, 2017
1 parent b2d077d commit b785ed0
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions torch/nn/_functions/loss.py
@@ -9,20 +9,14 @@ def __init__(self, margin=0, size_average=True):
         self.margin = margin
         self.size_average = size_average
 
-    def _new_idx(self, input):
-        if torch.typename(input) == 'torch.cuda.FloatTensor':
-            return torch.cuda.ByteTensor()
-        else:
-            return torch.ByteTensor()
-
     def forward(self, input1, input2, y):
         self.w1 = input1.new()
         self.w22 = input1.new()
         self.w = input1.new()
         self.w32 = input1.new()
         self._outputs = input1.new()
 
-        _idx = self._new_idx(input1)
+        _idx = input1.new().byte()
 
         buffer = torch.mul(input1, input2)
         torch.sum(buffer, 1, out=self.w1)
@@ -61,7 +55,7 @@ def backward(self, grad_output):
         v1, v2, y = self.saved_tensors
 
         buffer = v1.new()
-        _idx = self._new_idx(v1)
+        _idx = v1.new().byte()
 
         gw1 = grad_output.new()
         gw2 = grad_output.new()
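
The removed `_new_idx` helper compared the type name against `torch.cuda.FloatTensor` only, so any other CUDA tensor type (e.g. `torch.cuda.DoubleTensor`) fell through to the CPU `torch.ByteTensor()` branch and later caused a device mismatch. Allocating the index tensor with `input1.new().byte()` derives it from the input itself, so it lands on the input's device for every dtype. A minimal sketch of the pattern (the tensors here are illustrative, not the module's actual inputs):

    import torch

    x = torch.randn(4).double()              # torch.DoubleTensor on CPU
    print(torch.typename(x.new().byte()))    # 'torch.ByteTensor' -- same device as x

    if torch.cuda.is_available():
        y = torch.randn(4).cuda().double()   # torch.cuda.DoubleTensor
        # The old string check misses every non-float CUDA type:
        print(torch.typename(y) == 'torch.cuda.FloatTensor')   # False
        # new().byte() still allocates on the correct device:
        print(torch.typename(y.new().byte()))                  # 'torch.cuda.ByteTensor'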
2 changes: 1 addition & 1 deletion torch/nn/_functions/thnn/sparse.py
@@ -77,7 +77,7 @@ def backward(self, grad_output):
         indices = indices.view(-1)
 
         with torch.cuda.device_of(grad_output):
-            if torch.typename(grad_output) == 'torch.cuda.FloatTensor':
+            if grad_output.is_cuda:
                 _sorted = torch.cuda.LongTensor()
                 _indices = torch.cuda.LongTensor()
                 _count = torch.cuda.LongTensor()
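
The Embedding backward applies the same idea with `grad_output.is_cuda`, which is true for every CUDA tensor type, whereas the old string comparison matched only `torch.cuda.FloatTensor`; double-precision embedding gradients on the GPU would therefore skip the branch that allocates the CUDA `LongTensor` scratch buffers. A short illustrative check (hypothetical tensor, assuming a CUDA device is available):

    import torch

    if torch.cuda.is_available():
        grad_output = torch.randn(3, 5).cuda().double()  # torch.cuda.DoubleTensor
        # Old check: only FloatTensor passes.
        print(torch.typename(grad_output) == 'torch.cuda.FloatTensor')  # False
        # New check: any CUDA tensor passes.
        print(grad_output.is_cuda)                                      # True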
