From 8f7255fb4b177b3f5c47df9cdec7649c0a08469e Mon Sep 17 00:00:00 2001 From: Baek JeongHun Date: Sat, 3 Aug 2019 08:03:46 +0000 Subject: [PATCH] .cuda() to .to(device) --- demo.py | 12 +++++------- modules/prediction.py | 13 +++++++------ modules/transformation.py | 3 ++- test.py | 13 +++++++------ train.py | 9 +++++---- utils.py | 7 ++++--- 6 files changed, 30 insertions(+), 27 deletions(-) diff --git a/demo.py b/demo.py index 2c088006a1..8e0da24109 100755 --- a/demo.py +++ b/demo.py @@ -8,6 +8,7 @@ from utils import CTCLabelConverter, AttnLabelConverter from dataset import RawDataset, AlignCollate from model import Model +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def demo(opt): @@ -24,10 +25,7 @@ def demo(opt): print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) - - model = torch.nn.DataParallel(model) - if torch.cuda.is_available(): - model = model.cuda() + model = torch.nn.DataParallel(model).to(device) # load model print('loading pretrained model from %s' % opt.saved_model) @@ -47,10 +45,10 @@ def demo(opt): with torch.no_grad(): for image_tensors, image_path_list in demo_loader: batch_size = image_tensors.size(0) - image = image_tensors.cuda() + image = image_tensors.to(device) # For max length prediction - length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size) - text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0) + length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) + text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) if 'CTC' in opt.Prediction: preds = model(image, text_for_pred).log_softmax(2) diff --git a/modules/prediction.py b/modules/prediction.py index 37afab4c06..c8a40af0ec 100755 --- a/modules/prediction.py +++ b/modules/prediction.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class Attention(nn.Module): @@ -15,7 +16,7 @@ def __init__(self, input_size, hidden_size, num_classes): def _char_to_onehot(self, input_char, onehot_dim=38): input_char = input_char.unsqueeze(1) batch_size = input_char.size(0) - one_hot = torch.cuda.FloatTensor(batch_size, onehot_dim).zero_() + one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) one_hot = one_hot.scatter_(1, input_char, 1) return one_hot @@ -29,9 +30,9 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25): batch_size = batch_H.size(0) num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. - output_hiddens = torch.cuda.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0) - hidden = (torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0), - torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0)) + output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) + hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), + torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) if is_train: for i in range(num_steps): @@ -43,8 +44,8 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25): probs = self.generator(output_hiddens) else: - targets = torch.cuda.LongTensor(batch_size).fill_(0) # [GO] token - probs = torch.cuda.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0) + targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token + probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) for i in range(num_steps): char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) diff --git a/modules/transformation.py b/modules/transformation.py index 893147d014..e543e0cd97 100755 --- a/modules/transformation.py +++ b/modules/transformation.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class TPS_SpatialTransformerNetwork(nn.Module): @@ -149,7 +150,7 @@ def build_P_prime(self, batch_C_prime): batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( - batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2 + batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 return batch_P_prime # batch_size x n x 2 diff --git a/test.py b/test.py index 3e3ea40745..23320cfc71 100755 --- a/test.py +++ b/test.py @@ -12,6 +12,7 @@ from utils import CTCLabelConverter, AttnLabelConverter, Averager from dataset import hierarchical_dataset, AlignCollate from model import Model +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): @@ -76,10 +77,10 @@ def validation(model, criterion, evaluation_loader, converter, opt): for i, (image_tensors, labels) in enumerate(evaluation_loader): batch_size = image_tensors.size(0) length_of_data = length_of_data + batch_size - image = image_tensors.cuda() + image = image_tensors.to(device) # For max length prediction - length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size) - text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0) + length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) + text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) @@ -146,7 +147,7 @@ def test(opt): print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) - model = torch.nn.DataParallel(model).cuda() + model = torch.nn.DataParallel(model).to(device) # load model print('loading pretrained model from %s' % opt.saved_model) @@ -160,9 +161,9 @@ def test(opt): """ setup loss """ if 'CTC' in opt.Prediction: - criterion = torch.nn.CTCLoss(zero_infinity=True).cuda() + criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: - criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0 + criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 """ evaluation """ model.eval() diff --git a/train.py b/train.py index 6ec1588f36..301c2919f5 100755 --- a/train.py +++ b/train.py @@ -16,6 +16,7 @@ from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset from model import Model from test import validation +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def train(opt): @@ -63,7 +64,7 @@ def train(opt): continue # data parallel for multi-GPU - model = torch.nn.DataParallel(model).cuda() + model = torch.nn.DataParallel(model).to(device) model.train() if opt.continue_model != '': print(f'loading pretrained model from {opt.continue_model}') @@ -73,9 +74,9 @@ def train(opt): """ setup loss """ if 'CTC' in opt.Prediction: - criterion = torch.nn.CTCLoss(zero_infinity=True).cuda() + criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: - criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0 + criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 # loss averager loss_avg = Averager() @@ -121,7 +122,7 @@ def train(opt): while(True): # train part image_tensors, labels = train_dataset.get_batch() - image = image_tensors.cuda() + image = image_tensors.to(device) text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) batch_size = image.size(0) diff --git a/utils.py b/utils.py index e225bfdd1e..591ef57893 100755 --- a/utils.py +++ b/utils.py @@ -1,4 +1,5 @@ import torch +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class CTCLabelConverter(object): @@ -79,13 +80,13 @@ def encode(self, text, batch_max_length=25): # batch_max_length = max(length) # this is not allowed for multi-gpu setting batch_max_length += 1 # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. - batch_text = torch.cuda.LongTensor(len(text), batch_max_length + 1).fill_(0) + batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) for i, t in enumerate(text): text = list(t) text.append('[s]') text = [self.dict[char] for char in text] - batch_text[i][1:1 + len(text)] = torch.cuda.LongTensor(text) # batch_text[:, 0] = [GO] token - return (batch_text, torch.cuda.IntTensor(length)) + batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token + return (batch_text.to(device), torch.IntTensor(length).to(device)) def decode(self, text_index, length): """ convert text-index into text-label. """