Skip to content

Commit

Permalink
Replace hard-coded .cuda() calls with device-agnostic .to(device)
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Aug 3, 2019
1 parent 7498f48 commit 8f7255f
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 27 deletions.
12 changes: 5 additions & 7 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def demo(opt):
Expand All @@ -24,10 +25,7 @@ def demo(opt):
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
opt.SequenceModeling, opt.Prediction)

model = torch.nn.DataParallel(model)
if torch.cuda.is_available():
model = model.cuda()
model = torch.nn.DataParallel(model).to(device)

# load model
print('loading pretrained model from %s' % opt.saved_model)
Expand All @@ -47,10 +45,10 @@ def demo(opt):
with torch.no_grad():
for image_tensors, image_path_list in demo_loader:
batch_size = image_tensors.size(0)
image = image_tensors.cuda()
image = image_tensors.to(device)
# For max length prediction
length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)
length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

if 'CTC' in opt.Prediction:
preds = model(image, text_for_pred).log_softmax(2)
Expand Down
13 changes: 7 additions & 6 deletions modules/prediction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Attention(nn.Module):
Expand All @@ -15,7 +16,7 @@ def __init__(self, input_size, hidden_size, num_classes):
def _char_to_onehot(self, input_char, onehot_dim=38):
input_char = input_char.unsqueeze(1)
batch_size = input_char.size(0)
one_hot = torch.cuda.FloatTensor(batch_size, onehot_dim).zero_()
one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
one_hot = one_hot.scatter_(1, input_char, 1)
return one_hot

Expand All @@ -29,9 +30,9 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25):
batch_size = batch_H.size(0)
num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.

output_hiddens = torch.cuda.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0)
hidden = (torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0),
torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0))
output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device)
hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))

if is_train:
for i in range(num_steps):
Expand All @@ -43,8 +44,8 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25):
probs = self.generator(output_hiddens)

else:
targets = torch.cuda.LongTensor(batch_size).fill_(0) # [GO] token
probs = torch.cuda.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0)
targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token
probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)

for i in range(num_steps):
char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
Expand Down
3 changes: 2 additions & 1 deletion modules/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TPS_SpatialTransformerNetwork(nn.Module):
Expand Down Expand Up @@ -149,7 +150,7 @@ def build_P_prime(self, batch_C_prime):
batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2
batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2
batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
return batch_P_prime # batch_size x n x 2
13 changes: 7 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate
from model import Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
Expand Down Expand Up @@ -76,10 +77,10 @@ def validation(model, criterion, evaluation_loader, converter, opt):
for i, (image_tensors, labels) in enumerate(evaluation_loader):
batch_size = image_tensors.size(0)
length_of_data = length_of_data + batch_size
image = image_tensors.cuda()
image = image_tensors.to(device)
# For max length prediction
length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)
length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

Expand Down Expand Up @@ -146,7 +147,7 @@ def test(opt):
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
opt.SequenceModeling, opt.Prediction)
model = torch.nn.DataParallel(model).cuda()
model = torch.nn.DataParallel(model).to(device)

# load model
print('loading pretrained model from %s' % opt.saved_model)
Expand All @@ -160,9 +161,9 @@ def test(opt):

""" setup loss """
if 'CTC' in opt.Prediction:
criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0

""" evaluation """
model.eval()
Expand Down
9 changes: 5 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(opt):
Expand Down Expand Up @@ -63,7 +64,7 @@ def train(opt):
continue

# data parallel for multi-GPU
model = torch.nn.DataParallel(model).cuda()
model = torch.nn.DataParallel(model).to(device)
model.train()
if opt.continue_model != '':
print(f'loading pretrained model from {opt.continue_model}')
Expand All @@ -73,9 +74,9 @@ def train(opt):

""" setup loss """
if 'CTC' in opt.Prediction:
criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
# loss averager
loss_avg = Averager()

Expand Down Expand Up @@ -121,7 +122,7 @@ def train(opt):
while(True):
# train part
image_tensors, labels = train_dataset.get_batch()
image = image_tensors.cuda()
image = image_tensors.to(device)
text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
batch_size = image.size(0)

Expand Down
7 changes: 4 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CTCLabelConverter(object):
Expand Down Expand Up @@ -79,13 +80,13 @@ def encode(self, text, batch_max_length=25):
# batch_max_length = max(length) # this is not allowed for multi-gpu setting
batch_max_length += 1
# additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
batch_text = torch.cuda.LongTensor(len(text), batch_max_length + 1).fill_(0)
batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
for i, t in enumerate(text):
text = list(t)
text.append('[s]')
text = [self.dict[char] for char in text]
batch_text[i][1:1 + len(text)] = torch.cuda.LongTensor(text) # batch_text[:, 0] = [GO] token
return (batch_text, torch.cuda.IntTensor(length))
batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token
return (batch_text.to(device), torch.IntTensor(length).to(device))

def decode(self, text_index, length):
""" convert text-index into text-label. """
Expand Down

0 comments on commit 8f7255f

Please sign in to comment.