diff --git a/notebooks/LaTeX_OCR_test.ipynb b/notebooks/LaTeX_OCR_test.ipynb
index 5620fb6..a1c800f 100644
--- a/notebooks/LaTeX_OCR_test.ipynb
+++ b/notebooks/LaTeX_OCR_test.ipynb
@@ -1,10 +1,23 @@
 {
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "LaTeX OCR test.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
   "cells": [
     {
       "cell_type": "markdown",
-      "metadata": {
-        "id": "aaAqi3wku23I"
-      },
       "source": [
         "# LaTeX OCR\n",
         "In this colab you can convert an image of an equation into LaTeX code.\n",
@@ -14,7 +27,10 @@
         "Next, execute the cell below and upload the image(s).\n",
         "\n",
         "> Note: You can probably also run this project locally and with a GUI. Follow the steps on [GitHub](https://github.com/lukas-blecher/LaTeX-OCR)"
-      ]
+      ],
+      "metadata": {
+        "id": "aaAqi3wku23I"
+      }
     },
     {
       "cell_type": "code",
@@ -55,46 +71,32 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "CjrR3O07u3uH"
-      },
-      "outputs": [],
       "source": [
         "imgs = upload_files()\n",
         "predictions = []\n",
         "for name, f in imgs:\n",
         "    img = Image.open(f)\n",
-        "    math = model.generate(img)\n",
+        "    math = model(img)\n",
         "    print(math)\n",
         "    predictions.append('\\\\mathrm{%s} & \\\\displaystyle{%s}'%(name, math))\n",
         "Math(table%'\\\\\\\\'.join(predictions))"
-      ]
+      ],
+      "metadata": {
+        "id": "CjrR3O07u3uH"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "source": [
+        ""
+      ],
       "metadata": {
         "id": "ZqCH-4XoCkMO"
       },
-      "outputs": [],
-      "source": []
-    }
-  ],
-  "metadata": {
-    "colab": {
-      "collapsed_sections": [],
-      "name": "LaTeX OCR test.ipynb",
-      "provenance": []
-    },
-    "kernelspec": {
-      "display_name": "Python 3",
-      "name": "python3"
-    },
-    "language_info": {
-      "name": "python"
+      "execution_count": null,
+      "outputs": []
     }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 0
-}
+  ]
+}
\ No newline at end of file
diff --git a/pix2tex/api/app.py b/pix2tex/api/app.py
index da6dee8..bb0045d 100644
--- a/pix2tex/api/app.py
+++ b/pix2tex/api/app.py
@@ -45,7 +45,7 @@ async def predict(file: UploadFile = File(...)) -> str:
     """
     global model
     image = Image.open(file.file)
-    return model.generate(image)
+    return model(image)
 
 
 @app.post('/bytes/')
@@ -61,4 +61,4 @@ async def predict_from_bytes(file: bytes = File(...)) -> str: # , size: str = F
     global model
     #size = tuple(int(a) for a in size.split(','))
     image = Image.open(BytesIO(file))
-    return model.generate(image, resize=False)
+    return model(image, resize=False)
diff --git a/pix2tex/cli.py b/pix2tex/cli.py
index 81ecd89..d42339f 100644
--- a/pix2tex/cli.py
+++ b/pix2tex/cli.py
@@ -75,6 +75,7 @@ def __init__(self, arguments=None):
             download_checkpoints()
         self.model = get_model(self.args)
         self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
+        self.model.eval()
 
         if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
             self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
@@ -123,13 +124,8 @@ def __call__(self, img=None, resize=True) -> str:
         t = test_transform(image=img)['image'][:1].unsqueeze(0)
         im = t.to(self.args.device)
 
-        with torch.no_grad():
-            self.model.eval()
-            device = self.args.device
-            encoded = self.model.encoder(im.to(device))
-            dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
-                                              eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
-            pred = post_process(token2str(dec, self.tokenizer)[0])
+        dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
+        pred = post_process(token2str(dec, self.tokenizer)[0])
         try:
             clipboard.copy(pred)
         except:
@@ -220,7 +216,7 @@ def main():
             img = ImageGrab.grabclipboard()
         except:
             pass
-        pred = model.generate(img)
+        pred = model(img)
         output_prediction(pred, arguments)
     except KeyboardInterrupt:
         pass
diff --git a/pix2tex/eval.py b/pix2tex/eval.py
index c53ea53..8742988 100644
--- a/pix2tex/eval.py
+++ b/pix2tex/eval.py
@@ -50,10 +50,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
     for i, (seq, im) in pbar:
         if seq is None or im is None:
             continue
-        encoded = model.encoder(im.to(device))
         #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
-        dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
-                                     eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
+        dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
         pred = detokenize(dec, dataset.tokenizer)
         truth = detokenize(seq['input_ids'], dataset.tokenizer)
         bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
diff --git a/pix2tex/models/utils.py b/pix2tex/models/utils.py
index 9cd3433..934ee35 100644
--- a/pix2tex/models/utils.py
+++ b/pix2tex/models/utils.py
@@ -13,15 +13,28 @@ def __init__(self, encoder, decoder, args):
         self.decoder = decoder
         self.args = args
 
-    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
+    def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
+        if not device_ids or len(device_ids) == 1:
+            return self(x, **kwargs)
+        if output_device is None:
+            output_device = device_ids[0]
+        replicas = nn.parallel.replicate(self, device_ids)
+        inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+        kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
+        replicas = replicas[:len(inputs)]
+        kwargs = kwargs[:len(inputs)]
+        outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
+        return nn.parallel.gather(outputs, output_device).mean()
+
+    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
         encoded = self.encoder(x)
         out = self.decoder(tgt_seq, context=encoded, **kwargs)
         return out
 
     @torch.no_grad()
-    def generate(self, x: torch.Tensor):
-        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device),
-                                     self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
+    def generate(self, x: torch.Tensor, temperature: float = 0.25):
+        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len,
+                                     eos_token=self.args.eos_token, context=self.encoder(x), temperature=temperature)
 
 
 def get_model(args):
diff --git a/pix2tex/train.py b/pix2tex/train.py
index cf710b5..120e379 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -15,20 +15,6 @@
 from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
 
 
-def data_parallel(module, x:torch.Tensor, device_ids, output_device=None, **kwargs):
-    if not device_ids or len(device_ids) == 1:
-        return module(x, **kwargs)
-    if output_device is None:
-        output_device = device_ids[0]
-    replicas = nn.parallel.replicate(module, device_ids)
-    inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
-    kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
-    replicas = replicas[:len(inputs)]
-    kwargs = kwargs[:len(inputs)]
-    outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
-    return nn.parallel.gather(outputs, output_device)
-
-
 def gpu_memory_check(model, args):
     # check if largest batch can be handled by system
     try:
@@ -36,10 +22,7 @@ def gpu_memory_check(model, args):
         for _ in range(5):
             im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
             seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-            # model.decoder(seq, context=model.encoder(im)).sum().backward()
-            # encoded = data_parallel(model.encoder, inputs=im, device_ids=args.gpu_devices)
-            # loss = data_parallel(model.decoder, inputs=seq, device_ids=args.gpu_devices, context=encoded)
-            loss = data_parallel(model, im, device_ids=args.gpu_devices, tgt_seq=seq)
+            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
             loss.sum().backward()
     except RuntimeError:
         raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
@@ -60,7 +43,6 @@ def train(args):
     gpu_memory_check(model, args)
     if args.load_chkpt is not None:
         model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
-    encoder, decoder = model.encoder, model.decoder
     max_bleu, max_token_acc = 0, 0
     out_path = os.path.join(args.model_path, args.name)
     os.makedirs(out_path, exist_ok=True)
@@ -86,14 +68,9 @@ def save_models(e, step=0):
                 total_loss = 0
                 for j in range(0, len(im), microbatch):
                     tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
-                    # encoded = encoder(im[j:j+microbatch].to(device))
-                    # encoded = data_parallel(encoder, inputs=im[j:j+microbatch].to(device), device_ids=args.gpu_devices)
-                    # loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                    # loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                    # loss.backward()
-                    loss = data_parallel(model,im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
-                    loss.mean().backward()  # data parallism loss is a vector
-                    total_loss += loss.mean().item()
+                    loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
+                    loss.backward()  # data_parallel already gathers and averages the per-GPU losses, so loss is a scalar here
+                    total_loss += loss.item()
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                 opt.step()
                 scheduler.step()
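
A minimal usage sketch of the refactored inference path, assuming the LatexOCR wrapper from pix2tex.cli; the image path is hypothetical. Calling the wrapper goes through __call__, which now delegates to Model.generate instead of driving encoder/decoder by hand.

    from PIL import Image
    from pix2tex.cli import LatexOCR

    model = LatexOCR()                # loads the checkpoint; __init__ now also calls model.eval()
    img = Image.open('equation.png')  # hypothetical input image
    print(model(img))                 # __call__ delegates to Model.generate (default temperature 0.25)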