Skip to content

Commit b3069e3

Browse files
committed
auto download latest checkpoints
1 parent b5217d2 commit b3069e3

File tree

6 files changed

+40
-11
lines changed

6 files changed

+40
-11
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ notebooks/
133133
dataset/data/**
134134
wandb/
135135
checkpoints/**
136+
!checkpoints/*.py
136137
!**/.gitkeep
137138
.vscode
138139
.DS_Store

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ In order to render the math in many different fonts we use XeLaTeX, generate a
2323
## Using the model
2424
1. Download/Clone this repository
2525
2. For now you need to install the Python dependencies specified in `requirements.txt` (look [above](#Requirements))
26-
3. Download the `weights.pth` (and optionally `image_resizer.pth`) file from the [Releases](https://github.com/lukas-blecher/LaTeX-OCR/releases/latest)->Assets section and place it in the `checkpoints` directory
26+
3. The latest model checkpoint will be downloaded the first time the program is executed. Alternatively you can download the `weights.pth` (and optionally `image_resizer.pth`) file from the [Releases](https://github.com/lukas-blecher/LaTeX-OCR/releases/latest)->Assets section and place it in the `checkpoints` directory
2727

2828
Thanks to [@katie-lim](https://github.com/katie-lim), you can use a nice user interface as a quick way to get the model prediction. Just call the GUI with `python gui.py`. From here you can take a screenshot and the predicted latex code is rendered using [MathJax](https://www.mathjax.org/) and copied to your clipboard.
2929

checkpoints/.gitkeep

Whitespace-only changes.

checkpoints/get_latest_checkpoint.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import requests
2+
import os
3+
4+
url = 'https://github.com/lukas-blecher/LaTeX-OCR/releases/latest'
5+
6+
7+
def get_latest_tag():
8+
r = requests.get(url)
9+
tag = r.url.split('/')[-1]
10+
if tag == 'releases':
11+
return 'v0.0.1'
12+
return tag
13+
14+
15+
def download_checkpoints():
16+
tag = get_latest_tag()
17+
path = os.path.dirname(__file__)
18+
print('download weights', tag, 'to path', path)
19+
weights = 'https://github.com/lukas-blecher/LaTeX-OCR/releases/download/%s/weights.pth' % tag
20+
resizer = 'https://github.com/lukas-blecher/LaTeX-OCR/releases/download/%s/image_resizer.pth' % tag
21+
for url, name in zip([weights, resizer], ['weights.pth', 'resizer.pth']):
22+
r = requests.get(url, allow_redirects=True)
23+
open(os.path.join(path, name), "wb").write(r.content)
24+
25+
26+
if __name__ == '__main__':
27+
download_checkpoints()

models.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from x_transformers import *
6-
from x_transformers.autoregressive_wrapper import *
5+
# from x_transformers import *
6+
from x_transformers import TransformerWrapper, Decoder
7+
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p, entmax, ENTMAX_ALPHA
78
from timm.models.vision_transformer import VisionTransformer
89
from timm.models.vision_transformer_hybrid import HybridEmbed
910
from timm.models.resnetv2 import ResNetV2
@@ -16,7 +17,7 @@ def __init__(self, *args, **kwargs):
1617
super(CustomARWrapper, self).__init__(*args, **kwargs)
1718

1819
@torch.no_grad()
19-
def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
20+
def forward(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
2021
device = start_tokens.device
2122
was_training = self.net.training
2223
num_dims = len(start_tokens.shape)
@@ -42,9 +43,6 @@ def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter
4243
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
4344
probs = F.softmax(filtered_logits / temperature, dim=-1)
4445

45-
elif filter_logits_fn is entmax:
46-
probs = entmax(logits / temperature, alpha=ENTMAX_ALPHA, dim=-1)
47-
4846
sample = torch.multinomial(probs, 1)
4947

5048
out = torch.cat((out, sample), dim=-1)
@@ -150,6 +148,6 @@ def embed_layer(**x):
150148
seq = torch.randint(0, args.num_tokens, (args.batchsize, args.max_seq_len), device=args.device).long()
151149
decoder(seq, context=encoder(im)).sum().backward()
152150
model.zero_grad()
153-
torch.cuda.empty_cache()
151+
torch.cuda.empty_cache()
154152
del im, seq
155153
return model

pix2tex.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from dataset.latex2png import tex2pil
2222
from models import get_model
2323
from utils import *
24+
from checkpoints.get_latest_checkpoint import download_checkpoints
2425

2526
last_pic = None
2627

@@ -50,7 +51,8 @@ def initialize(arguments=None):
5051
args.update(**vars(arguments))
5152
args.wandb = False
5253
args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
53-
54+
if not os.path.exists(args.checkpoint):
55+
download_checkpoints()
5456
model = get_model(args)
5557
model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
5658

@@ -82,9 +84,10 @@ def call_model(args, model, image_resizer, tokenizer, img=None):
8284
if image_resizer is not None and not args.no_resize:
8385
with torch.no_grad():
8486
input_image = img.convert('RGB').copy()
85-
r, w = 1, input_image.size[0]
87+
r, w, h = 1, input_image.size[0], input_image.size[1]
8688
for _ in range(10):
87-
img = pad(minmax_size(input_image.resize((w, int(input_image.size[1]*r)), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions))
89+
h = int(h * r) # height to resize
90+
img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions))
8891
t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
8992
w = (image_resizer(t.to(args.device)).argmax(-1).item()+1)*32
9093
logging.info(r, img.size, (w, int(input_image.size[1]*r)))

0 commit comments

Comments
 (0)