diff --git a/.gitignore b/.gitignore
index 53d473c..e80e5f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,3 +138,5 @@ pix2tex/model/checkpoints/**
 !**/.gitkeep
 .vscode
 .DS_Store
+test/*
+
diff --git a/pix2tex/cli.py b/pix2tex/cli.py
index 31bd130..d42339f 100644
--- a/pix2tex/cli.py
+++ b/pix2tex/cli.py
@@ -75,6 +75,7 @@ def __init__(self, arguments=None):
             download_checkpoints()
         self.model = get_model(self.args)
         self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
+        self.model.eval()
 
         if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
             self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
@@ -123,13 +124,8 @@ def __call__(self, img=None, resize=True) -> str:
         t = test_transform(image=img)['image'][:1].unsqueeze(0)
         im = t.to(self.args.device)
 
-        with torch.no_grad():
-            self.model.eval()
-            device = self.args.device
-            encoded = self.model.encoder(im.to(device))
-            dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
-                                              eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
-            pred = post_process(token2str(dec, self.tokenizer)[0])
+        dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
+        pred = post_process(token2str(dec, self.tokenizer)[0])
         try:
             clipboard.copy(pred)
         except:
diff --git a/pix2tex/eval.py b/pix2tex/eval.py
index c53ea53..8742988 100644
--- a/pix2tex/eval.py
+++ b/pix2tex/eval.py
@@ -50,10 +50,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
     for i, (seq, im) in pbar:
         if seq is None or im is None:
             continue
-        encoded = model.encoder(im.to(device))
         #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
-        dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
-                                     eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
+        dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
         pred = detokenize(dec, dataset.tokenizer)
         truth = detokenize(seq['input_ids'], dataset.tokenizer)
         bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
diff --git a/pix2tex/model/settings/config-vit.yaml b/pix2tex/model/settings/config-vit.yaml
new file mode 100644
index 0000000..f434be2
--- /dev/null
+++ b/pix2tex/model/settings/config-vit.yaml
@@ -0,0 +1,52 @@
+gpu_devices: null #[0,1,2,3,4,5,6,7]
+betas:
+- 0.9
+- 0.999
+batchsize: 64
+bos_token: 1
+channels: 1
+data: dataset/data/train.pkl
+debug: false
+decoder_args:
+  attn_on_attn: true
+  cross_attend: true
+  ff_glu: true
+  rel_pos_bias: false
+  use_scalenorm: false
+dim: 256
+emb_dropout: 0
+encoder_depth: 4
+eos_token: 2
+epochs: 10
+gamma: 0.9995
+heads: 8
+id: null
+load_chkpt: null
+lr: 0.0005
+lr_step: 30
+max_height: 192
+max_seq_len: 512
+max_width: 672
+min_height: 32
+min_width: 32
+micro_batchsize: 64
+model_path: checkpoints_add
+name: pix2tex-vit
+num_layers: 4
+num_tokens: 8000
+optimizer: Adam
+output_path: outputs
+pad: false
+pad_token: 0
+patch_size: 16
+sample_freq: 1000
+save_freq: 5
+scheduler: StepLR
+seed: 42
+encoder_structure: vit
+temperature: 0.2
+test_samples: 5
+testbatchsize: 20
+tokenizer: dataset/tokenizer.json
+valbatches: 100
+valdata: dataset/data/val.pkl
\ No newline at end of file
diff --git a/pix2tex/model/settings/config.yaml b/pix2tex/model/settings/config.yaml
index 90bae70..fa1b3b7 100644
--- a/pix2tex/model/settings/config.yaml
+++ b/pix2tex/model/settings/config.yaml
@@ -1,3 +1,4 @@
+gpu_devices: null #[0,1,2,3,4,5,6,7]
 backbone_layers:
 - 2
 - 3
@@ -45,6 +46,7 @@ sample_freq: 3000
 save_freq: 5
 scheduler: StepLR
 seed: 42
+encoder_structure: hybrid
 temperature: 0.2
 test_samples: 5
 testbatchsize: 20
diff --git a/pix2tex/model/settings/debug.yaml b/pix2tex/model/settings/debug.yaml
index bcb8081..94e3b77 100644
--- a/pix2tex/model/settings/debug.yaml
+++ b/pix2tex/model/settings/debug.yaml
@@ -51,6 +51,7 @@ decoder_args:
 heads: 8
 num_tokens: 8000
 max_seq_len: 1024
+encoder_structure: hybrid
 
 # Other
 seed: 42
diff --git a/pix2tex/models.py b/pix2tex/models.py
deleted file mode 100644
index 42631a0..0000000
--- a/pix2tex/models.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# from x_transformers import *
-from x_transformers import TransformerWrapper, Decoder
-from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p, entmax, ENTMAX_ALPHA
-from timm.models.vision_transformer import VisionTransformer
-from timm.models.vision_transformer_hybrid import HybridEmbed
-from timm.models.resnetv2 import ResNetV2
-from timm.models.layers import StdConv2dSame
-from einops import rearrange, repeat
-
-
-class CustomARWrapper(AutoregressiveWrapper):
-    def __init__(self, *args, **kwargs):
-        super(CustomARWrapper, self).__init__(*args, **kwargs)
-
-    @torch.no_grad()
-    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
-        device = start_tokens.device
-        was_training = self.net.training
-        num_dims = len(start_tokens.shape)
-
-        if num_dims == 1:
-            start_tokens = start_tokens[None, :]
-
-        b, t = start_tokens.shape
-
-        self.net.eval()
-        out = start_tokens
-        mask = kwargs.pop('mask', None)
-        if mask is None:
-            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
-
-        for _ in range(seq_len):
-            x = out[:, -self.max_seq_len:]
-            mask = mask[:, -self.max_seq_len:]
-            # print('arw:',out.shape)
-            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
-
-            if filter_logits_fn in {top_k, top_p}:
-                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
-                probs = F.softmax(filtered_logits / temperature, dim=-1)
-
-            sample = torch.multinomial(probs, 1)
-
-            out = torch.cat((out, sample), dim=-1)
-            mask = F.pad(mask, (0, 1), value=True)
-
-            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
-                break
-
-        out = out[:, t:]
-
-        if num_dims == 1:
-            out = out.squeeze(0)
-
-        self.net.train(was_training)
-        return out
-
-
-class CustomVisionTransformer(VisionTransformer):
-    def __init__(self, img_size=224, patch_size=16, *args, **kwargs):
-        super(CustomVisionTransformer, self).__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
-        self.height, self.width = img_size
-        self.patch_size = patch_size
-
-    def forward_features(self, x):
-        B, c, h, w = x.shape
-        x = self.patch_embed(x)
-
-        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        x = torch.cat((cls_tokens, x), dim=1)
-        h, w = h//self.patch_size, w//self.patch_size
-        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
-        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
-        x += self.pos_embed[:, pos_emb_ind]
-        #x = x + self.pos_embed
-        x = self.pos_drop(x)
-
-        for blk in self.blocks:
-            x = blk(x)
-
-        x = self.norm(x)
-        return x
-
-
-class Model(nn.Module):
-    """ViT encoder, transformer decoder architecture"""
-
-    def __init__(self, encoder: CustomVisionTransformer, decoder: CustomARWrapper, args, temp: float = .333):
-        super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.bos_token = args.bos_token
-        self.eos_token = args.eos_token
-        self.max_seq_len = args.max_seq_len
-        self.temperature = temp
-
-    @torch.no_grad()
-    def forward(self, x: torch.Tensor):
-        device = x.device
-        encoded = self.encoder(x.to(device))
-        dec = self.decoder.generate(torch.LongTensor([self.bos_token]*len(x))[:, None].to(device), self.max_seq_len,
-                                    eos_token=self.eos_token, context=encoded, temperature=self.temperature)
-        return dec
-
-
-def get_model(args, training=False):
-    backbone = ResNetV2(
-        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
-        preact=False, stem_type='same', conv_layer=StdConv2dSame)
-    min_patch_size = 2**(len(args.backbone_layers)+1)
-
-    def embed_layer(**x):
-        ps = x.pop('patch_size', min_patch_size)
-        assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
-        return HybridEmbed(**x, patch_size=ps//min_patch_size, backbone=backbone)
-
-    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
-                                      patch_size=args.patch_size,
-                                      in_chans=args.channels,
-                                      num_classes=0,
-                                      embed_dim=args.dim,
-                                      depth=args.encoder_depth,
-                                      num_heads=args.heads,
-                                      embed_layer=embed_layer
-                                      ).to(args.device)
-
-    decoder = CustomARWrapper(
-        TransformerWrapper(
-            num_tokens=args.num_tokens,
-            max_seq_len=args.max_seq_len,
-            attn_layers=Decoder(
-                dim=args.dim,
-                depth=args.num_layers,
-                heads=args.heads,
-                **args.decoder_args
-            )),
-        pad_value=args.pad_token
-    ).to(args.device)
-    if 'wandb' in args and args.wandb:
-        import wandb
-        wandb.watch((encoder, decoder.net.attn_layers))
-    model = Model(encoder, decoder, args)
-    if training:
-        # check if largest batch can be handled by system
-        try:
-            batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
-            for _ in range(5):
-                im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
-                seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-                decoder(seq, context=encoder(im)).sum().backward()
-        except RuntimeError:
-            raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize."%(batchsize, args.max_height, args.max_width))
-        model.zero_grad()
-        torch.cuda.empty_cache()
-        del im, seq
-    return model
diff --git a/pix2tex/models/__init__.py b/pix2tex/models/__init__.py
new file mode 100644
index 0000000..90f60fd
--- /dev/null
+++ b/pix2tex/models/__init__.py
@@ -0,0 +1 @@
+from .utils import *
\ No newline at end of file
diff --git a/pix2tex/models/hybrid.py b/pix2tex/models/hybrid.py
new file mode 100644
index 0000000..25f5616
--- /dev/null
+++ b/pix2tex/models/hybrid.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+from timm.models.vision_transformer import VisionTransformer
+from timm.models.vision_transformer_hybrid import HybridEmbed
+from timm.models.resnetv2 import ResNetV2
+from timm.models.layers import StdConv2dSame
+from einops import repeat
+
+class CustomVisionTransformer(VisionTransformer):
+    def __init__(self, img_size=224, patch_size=16, *args, **kwargs):
+        super(CustomVisionTransformer, self).__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
+        self.height, self.width = img_size
+        self.patch_size = patch_size
+
+    def forward_features(self, x):
+        B, c, h, w = x.shape
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        h, w = h//self.patch_size, w//self.patch_size
+        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
+        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
+        x += self.pos_embed[:, pos_emb_ind]
+        #x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        return x
+
+
+def get_encoder(args):
+    backbone = ResNetV2(
+        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
+        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    min_patch_size = 2**(len(args.backbone_layers)+1)
+
+    def embed_layer(**x):
+        ps = x.pop('patch_size', min_patch_size)
+        assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
+        return HybridEmbed(**x, patch_size=ps//min_patch_size, backbone=backbone)
+
+    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
+                                      patch_size=args.patch_size,
+                                      in_chans=args.channels,
+                                      num_classes=0,
+                                      embed_dim=args.dim,
+                                      depth=args.encoder_depth,
+                                      num_heads=args.heads,
+                                      embed_layer=embed_layer
+                                      )
+    return encoder
\ No newline at end of file
diff --git a/pix2tex/models/transformer.py b/pix2tex/models/transformer.py
new file mode 100644
index 0000000..cfb2d30
--- /dev/null
+++ b/pix2tex/models/transformer.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn.functional as F
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p
+from x_transformers import TransformerWrapper, Decoder
+
+
+class CustomARWrapper(AutoregressiveWrapper):
+    def __init__(self, *args, **kwargs):
+        super(CustomARWrapper, self).__init__(*args, **kwargs)
+
+    @torch.no_grad()
+    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
+        device = start_tokens.device
+        was_training = self.net.training
+        num_dims = len(start_tokens.shape)
+
+        if num_dims == 1:
+            start_tokens = start_tokens[None, :]
+
+        b, t = start_tokens.shape
+
+        self.net.eval()
+        out = start_tokens
+        mask = kwargs.pop('mask', None)
+        if mask is None:
+            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
+
+        for _ in range(seq_len):
+            x = out[:, -self.max_seq_len:]
+            mask = mask[:, -self.max_seq_len:]
+            # print('arw:',out.shape)
+            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
+
+            if filter_logits_fn in {top_k, top_p}:
+                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
+                probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+            sample = torch.multinomial(probs, 1)
+
+            out = torch.cat((out, sample), dim=-1)
+            mask = F.pad(mask, (0, 1), value=True)
+
+            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
+                break
+
+        out = out[:, t:]
+
+        if num_dims == 1:
+            out = out.squeeze(0)
+
+        self.net.train(was_training)
+        return out
+
+
+def get_decoder(args):
+    return CustomARWrapper(
+        TransformerWrapper(
+            num_tokens=args.num_tokens,
+            max_seq_len=args.max_seq_len,
+            attn_layers=Decoder(
+                dim=args.dim,
+                depth=args.num_layers,
+                heads=args.heads,
+                **args.decoder_args
+            )),
+        pad_value=args.pad_token)
diff --git a/pix2tex/models/utils.py b/pix2tex/models/utils.py
new file mode 100644
index 0000000..d1c3599
--- /dev/null
+++ b/pix2tex/models/utils.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+
+from . import hybrid
+from . import vit
+from . import transformer
+
+
+class Model(nn.Module):
+    def __init__(self, encoder, decoder, args):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.args = args
+
+    def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
+        if not device_ids or len(device_ids) == 1:
+            return self(x, **kwargs)
+        if output_device is None:
+            output_device = device_ids[0]
+        replicas = nn.parallel.replicate(self, device_ids)
+        inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+        kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
+        replicas = replicas[:len(inputs)]
+        kwargs = kwargs[:len(inputs)]
+        outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
+        return nn.parallel.gather(outputs, output_device).mean()
+
+    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
+        encoded = self.encoder(x)
+        out = self.decoder(tgt_seq, context=encoded, **kwargs)
+        return out
+
+    @torch.no_grad()
+    def generate(self, x: torch.Tensor, temperature: float = 0.25):
+        return self.decoder.generate((torch.LongTensor([self.args.bos_token]*len(x))[:, None]).to(x.device), self.args.max_seq_len,
+                                     eos_token=self.args.eos_token, context=self.encoder(x), temperature=temperature)
+
+
+def get_model(args):
+    if args.encoder_structure.lower() == 'vit':
+        encoder = vit.get_encoder(args)
+    elif args.encoder_structure.lower() == 'hybrid':
+        encoder = hybrid.get_encoder(args)
+    else:
+        raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
+    decoder = transformer.get_decoder(args)
+    encoder.to(args.device)
+    decoder.to(args.device)
+    model = Model(encoder, decoder, args)
+    if args.wandb:
+        import wandb
+        wandb.watch(model)
+
+    return model
diff --git a/pix2tex/models/vit.py b/pix2tex/models/vit.py
new file mode 100644
index 0000000..80f47d2
--- /dev/null
+++ b/pix2tex/models/vit.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+from x_transformers import Encoder
+from einops import rearrange, repeat
+
+
+class ViTransformerWrapper(nn.Module):
+    def __init__(
+        self,
+        *,
+        max_width,
+        max_height,
+        patch_size,
+        attn_layers,
+        channels=1,
+        num_classes=None,
+        dropout=0.,
+        emb_dropout=0.
+    ):
+        super().__init__()
+        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
+        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
+        dim = attn_layers.dim
+        num_patches = (max_width // patch_size)*(max_height // patch_size)
+        patch_dim = channels * patch_size ** 2
+
+        self.patch_size = patch_size
+        self.max_width = max_width
+        self.max_height = max_height
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.attn_layers = attn_layers
+        self.norm = nn.LayerNorm(dim)
+        #self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
+
+    def forward(self, img, **kwargs):
+        p = self.patch_size
+
+        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
+        x = self.patch_to_embedding(x)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        h, w = torch.tensor(img.shape[2:])//p
+        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
+        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
+        x += self.pos_embedding[:, pos_emb_ind]
+        x = self.dropout(x)
+
+        x = self.attn_layers(x, **kwargs)
+        x = self.norm(x)
+
+        return x
+
+
+def get_encoder(args):
+    return ViTransformerWrapper(
+        max_width=args.max_width,
+        max_height=args.max_height,
+        channels=args.channels,
+        patch_size=args.patch_size,
+        emb_dropout=args.get('emb_dropout', 0),
+        attn_layers=Encoder(
+            dim=args.dim,
+            depth=args.num_layers,
+            heads=args.heads,
+        )
+    )
diff --git a/pix2tex/train.py b/pix2tex/train.py
index baba77d..5d6be47 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -8,10 +8,27 @@
 from munch import Munch
 from tqdm.auto import tqdm
 import wandb
-
+import torch.nn as nn
 from pix2tex.eval import evaluate
 from pix2tex.models import get_model
-from pix2tex.utils import *
+# from pix2tex.utils import *
+from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
+
+
+def gpu_memory_check(model, args):
+    # check if largest batch can be handled by system
+    try:
+        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
+        for _ in range(5):
+            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
+            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
+            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
+            loss.sum().backward()
+    except RuntimeError:
+        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
+    model.zero_grad()
+    torch.cuda.empty_cache()
+    del im, seq
 
 
 def train(args):
@@ -22,16 +39,16 @@ def train(args):
     valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
     valdataloader.update(**valargs)
     device = args.device
-    model = get_model(args, training=True)
+    model = get_model(args)
+    gpu_memory_check(model, args)
     if args.load_chkpt is not None:
         model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
-    encoder, decoder = model.encoder, model.decoder
-
+    max_bleu, max_token_acc = 0, 0
     out_path = os.path.join(args.model_path, args.name)
     os.makedirs(out_path, exist_ok=True)
 
-    def save_models(e):
-        torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d.pth' % (args.name, e+1)))
+    def save_models(e, step=0):
+        torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
         yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))
 
     opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)
@@ -40,6 +57,7 @@ def save_models(e):
     microbatch = args.get('micro_batchsize', -1)
     if microbatch == -1:
         microbatch = args.batchsize
+
     try:
         for e in range(args.epoch, args.epochs):
             args.epoch = e
@@ -50,9 +68,8 @@ def save_models(e):
                     total_loss = 0
                     for j in range(0, len(im), microbatch):
                         tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
-                        encoded = encoder(im[j:j+microbatch].to(device))
-                        loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                        loss.backward()
+                        loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
+                        loss.backward()  # data parallism loss is a vector
                         total_loss += loss.item()
                     torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                     opt.step()
@@ -61,9 +78,12 @@ def save_models(e):
                     if args.wandb:
                         wandb.log({'train/loss': total_loss})
                 if (i+1+len(dataloader)*e) % args.sample_freq == 0:
-                    evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
+                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
+                    if bleu_score > max_bleu and token_accuracy > max_token_acc:
+                        max_bleu, max_token_acc = bleu_score, token_accuracy
+                        save_models(e, step=i)
             if (e+1) % args.save_freq == 0:
-                save_models(e)
+                save_models(e, step=len(dataloader))
             if args.wandb:
                 wandb.log({'train/epoch': e+1})
     except KeyboardInterrupt:
@@ -92,4 +112,5 @@ def save_models(e):
         if not parsed_args.resume:
             args.id = wandb.util.generate_id()
         wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
+        args = Munch(wandb.config)
     train(args)
diff --git a/pix2tex/utils/utils.py b/pix2tex/utils/utils.py
index b99948d..07d4708 100644
--- a/pix2tex/utils/utils.py
+++ b/pix2tex/utils/utils.py
@@ -53,7 +53,7 @@ def parse_args(args, **kwargs) -> Munch:
     args = Munch({'epoch': 0}, **args)
     kwargs = Munch({'no_cuda': False, 'debug': False}, **kwargs)
     args.wandb = not kwargs.debug and not args.debug
-    args.device = 'cuda' if torch.cuda.is_available() and not kwargs.no_cuda else 'cpu'
+    args.device = get_device(args, kwargs)
     args.max_dimensions = [args.max_width, args.max_height]
     args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
     if 'decoder_args' not in args or args.decoder_args is None:
@@ -61,6 +61,17 @@
     return args
 
 
+def get_device(args, kwargs):
+    device = 'cpu'
+    available_gpus = torch.cuda.device_count()
+    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else range(available_gpus)
+    if available_gpus > 0 and not kwargs.no_cuda:
+        device = 'cuda:%d' % args.gpu_devices[0] if args.gpu_devices else 0
+        assert available_gpus >= len(args.gpu_devices), "Available %d gpu, but specified gpu %s." % (available_gpus, ','.join(map(str, args.gpu_devices)))
+        assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))), ','.join(map(str, args.gpu_devices)))
+    return device
+
+
 def token2str(tokens, tokenizer) -> list:
     if len(tokens.shape) == 1:
         tokens = tokens[None, :]
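
Usage note (illustration, not part of the patch): with this change, cli.py and eval.py no longer drive the encoder and decoder separately; both call Model.generate from pix2tex/models/utils.py, which seeds the decoder with bos_token and decodes autoregressively against the encoded image. A minimal CPU-only sketch of the refactored API, assuming an installed pix2tex checkout; the config path is one of the files in this patch, debug=True only skips the wandb.watch hook, and the commented checkpoint path is hypothetical:

import yaml
import torch
from munch import Munch

from pix2tex.models import get_model
from pix2tex.utils import parse_args

# config.yaml -> encoder_structure: hybrid; config-vit.yaml -> the new pure-ViT encoder
params = yaml.safe_load(open('pix2tex/model/settings/config.yaml'))
args = parse_args(Munch(params), no_cuda=True, debug=True)

model = get_model(args).eval()
# model.load_state_dict(torch.load('checkpoint.pth', map_location=args.device))  # hypothetical weights path

# dummy grayscale batch (B, channels, H, W); real inputs come from test_transform in cli.py,
# and without trained weights the decoded sequence is meaningless
im = torch.zeros(1, args.channels, 64, 256, device=args.device)
with torch.no_grad():
    dec = model.generate(im, temperature=args.get('temperature', .25))  # LongTensor of token ids for token2str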
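
Encoder selection (illustration, not part of the patch): get_model in pix2tex/models/utils.py dispatches on the new encoder_structure key, building either the ResNetV2+ViT hybrid from pix2tex/models/hybrid.py or the patch-embedding ViTransformerWrapper from pix2tex/models/vit.py, while the decoder always comes from pix2tex/models/transformer.py. A sketch showing that the switch is config-only (paths are illustrative):

import yaml
from munch import Munch

from pix2tex.models import get_model
from pix2tex.utils import parse_args

for cfg in ('pix2tex/model/settings/config.yaml',       # encoder_structure: hybrid
            'pix2tex/model/settings/config-vit.yaml'):  # encoder_structure: vit
    args = parse_args(Munch(yaml.safe_load(open(cfg))), no_cuda=True, debug=True)
    model = get_model(args)
    # CustomVisionTransformer for the hybrid config, ViTransformerWrapper for the ViT config
    print(args.encoder_structure, type(model.encoder).__name__)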
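
Training path (illustration, not part of the patch): train.py now routes every micro-batch through Model.data_parallel, which falls back to a plain forward pass when gpu_devices resolves to zero or one device, and otherwise replicates the model over the listed GPUs, scatters the batch, and averages the gathered losses; gpu_memory_check probes this with the largest batch before training starts. A sketch of a single optimizer step, reusing `args` and `model` from the first sketch and dummy tensors in place of an Im2LatexDataset batch (shapes are illustrative):

import torch
from pix2tex.utils import get_optimizer

opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)

# dummy batch; real images and token sequences come from Im2LatexDataset
im = torch.zeros(2, args.channels, 64, 256, device=args.device)
tgt_seq = torch.randint(0, args.num_tokens, (2, 64), device=args.device)
tgt_mask = torch.ones_like(tgt_seq).bool()

model.train()
# with an empty or single-entry gpu_devices this is equivalent to model(im, tgt_seq=tgt_seq, mask=tgt_mask)
loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
opt.step()
opt.zero_grad()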