From 3ffb30c6126c618eb9df57efa7a9b0738a76aacc Mon Sep 17 00:00:00 2001
From: TITC
Date: Thu, 19 May 2022 14:29:06 +0800
Subject: [PATCH] remove nn.DataParallel

---
 pix2tex/eval.py                        |  3 +-
 pix2tex/model/settings/config-vit.yaml |  1 +
 pix2tex/model/settings/config.yaml     |  1 +
 pix2tex/models/utils.py                | 20 ++----------
 pix2tex/train.py                       | 42 +++++++++++++++++++++++---
 5 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/pix2tex/eval.py b/pix2tex/eval.py
index f735eab..c53ea53 100644
--- a/pix2tex/eval.py
+++ b/pix2tex/eval.py
@@ -52,8 +52,7 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
             continue
         encoded = model.encoder(im.to(device))
         #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
-        generate = model.decoder.module.generate if torch.cuda.device_count() > 1 else model.decoder.generate
-        dec = generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
+        dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
                        eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
         pred = detokenize(dec, dataset.tokenizer)
         truth = detokenize(seq['input_ids'], dataset.tokenizer)
diff --git a/pix2tex/model/settings/config-vit.yaml b/pix2tex/model/settings/config-vit.yaml
index 002724f..f434be2 100644
--- a/pix2tex/model/settings/config-vit.yaml
+++ b/pix2tex/model/settings/config-vit.yaml
@@ -1,3 +1,4 @@
+gpu_devices: null #[0,1,2,3,4,5,6,7]
 betas:
 - 0.9
 - 0.999
diff --git a/pix2tex/model/settings/config.yaml b/pix2tex/model/settings/config.yaml
index f38e181..fa1b3b7 100644
--- a/pix2tex/model/settings/config.yaml
+++ b/pix2tex/model/settings/config.yaml
@@ -1,3 +1,4 @@
+gpu_devices: null #[0,1,2,3,4,5,6,7]
 backbone_layers:
 - 2
 - 3
diff --git a/pix2tex/models/utils.py b/pix2tex/models/utils.py
index d230a3f..6d20420 100644
--- a/pix2tex/models/utils.py
+++ b/pix2tex/models/utils.py
@@ -18,7 +18,7 @@ def forward(self, x: torch.Tensor):
                                      self.args.max_seq_len,
                                      eos_token=self.args.eos_token, context=self.encoder(x))
 
-def get_model(args, training=False):
+def get_model(args):
     if args.encoder_structure.lower() == 'vit':
         encoder = vit.get_encoder(args)
     elif args.encoder_structure.lower() == 'hybrid':
@@ -26,27 +26,11 @@ def get_model(args, training=False):
     else:
         raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
     decoder = transformer.get_decoder(args)
-    num_available_gpus = torch.cuda.device_count()
-    if num_available_gpus > 1:
-        encoder = nn.DataParallel(encoder)
-        decoder = nn.DataParallel(decoder)
     encoder.to(args.device)
     decoder.to(args.device)
     model = Model(encoder, decoder, args)
    if args.wandb:
         import wandb
         wandb.watch(model)
-    if training:
-        # check if largest batch can be handled by system
-        try:
-            batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
-            for _ in range(5):
-                im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
-                seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-                decoder(seq, context=encoder(im)).sum().backward()
-        except RuntimeError:
-            raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
-        model.zero_grad()
-        torch.cuda.empty_cache()
-        del im, seq
+
     return model
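
Note on the eval.py and models/utils.py changes above: nn.DataParallel wraps a module, so custom methods such as generate() are only reachable through the wrapper's .module attribute, which is why evaluate() previously had to branch on torch.cuda.device_count(). With the wrapper gone, generate() is called directly, and the GPU memory probe no longer belongs in get_model() (it moves to train.py below). A minimal sketch of the .module indirection, using a toy Decoder class that is not part of the patch:

    import torch.nn as nn

    class Decoder(nn.Module):
        def generate(self, x):        # custom method, unknown to nn.Module
            return x

    decoder = Decoder()
    wrapped = nn.DataParallel(decoder)
    wrapped.module.generate(1)        # before this patch: indirection required
    decoder.generate(1)               # after this patch: direct call
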
diff --git a/pix2tex/train.py b/pix2tex/train.py
index 09d44d7..7300c7e 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -8,13 +8,44 @@
 from munch import Munch
 from tqdm.auto import tqdm
 import wandb
-
+import torch.nn as nn
 from pix2tex.eval import evaluate
 from pix2tex.models import get_model
 # from pix2tex.utils import *
 from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
 
 
+def data_parallel(module, inputs, device_ids, output_device=None, **kwargs):
+    if not device_ids or len(device_ids) == 1:
+        return module(inputs, **kwargs)
+    if output_device is None:
+        output_device = device_ids[0]
+    replicas = nn.parallel.replicate(module, device_ids)
+    inputs = nn.parallel.scatter(inputs, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+    kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
+    replicas = replicas[:len(inputs)]
+    kwargs = kwargs[:len(inputs)]
+    outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
+    return nn.parallel.gather(outputs, output_device)
+
+
+def gpu_memory_check(model, args):
+    # check if largest batch can be handled by system
+    try:
+        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
+        for _ in range(5):
+            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
+            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
+            # model.decoder(seq, context=model.encoder(im)).sum().backward()
+            encoded = data_parallel(model.encoder, inputs=im, device_ids=args.gpu_devices)
+            loss = data_parallel(model.decoder, inputs=seq, device_ids=args.gpu_devices, context=encoded)
+            loss.sum().backward()
+    except RuntimeError:
+        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
+    model.zero_grad()
+    torch.cuda.empty_cache()
+    del im, seq
+
+
 def train(args):
     dataloader = Im2LatexDataset().load(args.data)
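
The data_parallel() helper added above composes the same functional primitives that nn.DataParallel uses internally (replicate, scatter, parallel_apply, gather), but it never wraps the module, so attributes and custom methods stay directly accessible. A minimal sketch of that pipeline on a toy module, assuming at least two CUDA devices (not part of the patch):

    import torch
    import torch.nn as nn

    module = nn.Linear(8, 4).to('cuda:0')
    inputs = torch.randn(6, 8, device='cuda:0')
    device_ids = [0, 1]

    replicas = nn.parallel.replicate(module, device_ids)   # one copy of the module per GPU
    chunks = nn.parallel.scatter(inputs, device_ids)       # split the batch along dim 0
    outputs = nn.parallel.parallel_apply(replicas[:len(chunks)], chunks)
    result = nn.parallel.gather(outputs, 0)                # concatenated back on cuda:0, shape (6, 4)

Because gather() concatenates the per-replica outputs, a module that returns a scalar loss comes back as a vector with one entry per GPU; the training-loop changes below therefore reduce with loss.mean() before calling backward().
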
@@ -24,7 +55,8 @@ def train(args):
     valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
     valdataloader.update(**valargs)
     device = args.device
-    model = get_model(args, training=True)
+    model = get_model(args)
+    gpu_memory_check(model, args)
     if args.load_chkpt is not None:
         model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
     encoder, decoder = model.encoder, model.decoder
@@ -53,8 +85,10 @@ def save_models(e, step=0):
         total_loss = 0
         for j in range(0, len(im), microbatch):
             tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
-            encoded = encoder(im[j:j+microbatch].to(device))
-            loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
+            # encoded = encoder(im[j:j+microbatch].to(device))
+            encoded = data_parallel(encoder, inputs=im[j:j+microbatch].to(device), device_ids=args.gpu_devices)
+            # loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
+            loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
             # loss.backward()
             loss.mean().backward()  # data parallelism: the loss is a vector
             total_loss += loss.mean().item()
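
After this patch, multi-GPU training is opt-in through the new gpu_devices config key (e.g. gpu_devices: [0, 1]) rather than being enabled automatically whenever torch.cuda.device_count() > 1. With the default gpu_devices: null, data_parallel() falls through to a plain forward call, so the same training code runs unchanged on CPU or a single GPU. A self-contained sketch of that fallback path (the multi-GPU branch is elided; see the full helper in the diff above):

    import torch
    import torch.nn as nn

    def data_parallel(module, inputs, device_ids, output_device=None, **kwargs):
        if not device_ids or len(device_ids) == 1:
            return module(inputs, **kwargs)   # single-device / CPU fallback
        raise NotImplementedError             # multi-GPU branch: see patch above

    out = data_parallel(nn.Linear(8, 4), torch.randn(6, 8), device_ids=None)
    print(out.shape)  # torch.Size([6, 4])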