Skip to content

Commit 4056b86

Browse files
committed
fix some more bugs introduced by package
1 parent 243908c commit 4056b86

File tree

7 files changed

+45
-25
lines changed

7 files changed

+45
-25
lines changed

README.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@ Always double check the result carefully. You can try to redo the prediction wit
3838
1. First we need to combine the images with their ground truth labels. I wrote a dataset class (which needs further improving) that saves the relative paths to the images with the LaTeX code they were rendered with. To generate the dataset pickle file run
3939

4040
```
41-
python -m pix2tex.dataset.dataset --equations path_to_textfile --images path_to_images --tokenizer dataset/tokenizer.json --out dataset.pkl
41+
python -m pix2tex.dataset.dataset --equations path_to_textfile --images path_to_images --out dataset.pkl
4242
```
43+
To use your own tokenizer pass it via `--tokenizer` (See below).
4344

4445
You can find my generated training data on the [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO) as well (formulae.zip - images, math.txt - labels). Repeat the step for the validation and test data. All use the same label text file.
4546

46-
2. Edit the `data` (and `valdata`) entry in the config file to the newly generated `.pkl` file. Change other hyperparameters if you want to. See `settings/config.yaml` for a template.
47+
2. Edit the `data` (and `valdata`) entry in the config file to the newly generated `.pkl` file. Change other hyperparameters if you want to. See `pix2tex/model/settings/config.yaml` for a template.
4748
3. Now for the actual training run
4849
```
4950
python -m pix2tex.train --config path_to_config_file

pix2tex/dataset/dataset.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from tempfile import tempdir
12
import albumentations as alb
23
from albumentations.pytorch import ToTensorV2
34
import torch
@@ -15,6 +16,7 @@
1516
from transformers import PreTrainedTokenizerFast
1617
from tqdm.auto import tqdm
1718

19+
from pix2tex.utils.utils import in_model_path
1820

1921
train_transform = alb.Compose(
2022
[
@@ -188,6 +190,11 @@ def load(self, filename, args=[]):
188190
Args:
189191
filename (str): Path to dataset
190192
"""
193+
if not os.path.exists(filename):
194+
with in_model_path():
195+
tempf = os.path.join('..', filename)
196+
if os.path.exists(tempf):
197+
filename = os.path.realpath(tempf)
191198
with open(filename, 'rb') as file:
192199
x = pickle.load(file)
193200
return x
@@ -201,7 +208,7 @@ def combine(self, x):
201208
for key in x.data.keys():
202209
if key in self.data.keys():
203210
self.data[key].extend(x.data[key])
204-
self.data[key]=list(set(self.data[key]))
211+
self.data[key] = list(set(self.data[key]))
205212
else:
206213
self.data[key] = x.data[key]
207214
self._get_size()
@@ -230,6 +237,12 @@ def update(self, **kwargs):
230237
if self.min_dimensions[0] <= k[0] <= self.max_dimensions[0] and self.min_dimensions[1] <= k[1] <= self.max_dimensions[1]:
231238
temp[k] = self.data[k]
232239
self.data = temp
240+
if 'tokenizer' in kwargs:
241+
tokenizer_file = kwargs['tokenizer']
242+
if not os.path.exists(tokenizer_file):
243+
with in_model_path():
244+
tokenizer_file = os.path.realpath(tokenizer_file)
245+
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
233246
self._get_size()
234247
iter(self)
235248

@@ -251,13 +264,16 @@ def generate_tokenizer(equations, output, vocab_size):
251264
parser.add_argument('-i', '--images', type=str, nargs='+', default=None, help='Image folders')
252265
parser.add_argument('-e', '--equations', type=str, nargs='+', default=None, help='equations text files')
253266
parser.add_argument('-t', '--tokenizer', default=None, help='Pretrained tokenizer file')
254-
parser.add_argument('-o', '--out', required=True, help='output file')
267+
parser.add_argument('-o', '--out', type=str, required=True, help='output file')
255268
parser.add_argument('-s', '--vocab-size', default=8000, type=int, help='vocabulary size when training a tokenizer')
256269
args = parser.parse_args()
257-
if args.images is None and args.equations is not None and args.tokenizer is None:
270+
if args.tokenizer is None:
271+
with in_model_path():
272+
args.tokenizer = os.path.realpath(os.path.join('dataset', 'tokenizer.json'))
273+
if args.images is None and args.equations is not None:
258274
print('Generate tokenizer')
259275
generate_tokenizer(args.equations, args.out, args.vocab_size)
260-
elif args.images is not None and args.equations is not None and args.tokenizer is not None:
276+
elif args.images is not None and args.equations is not None:
261277
print('Generate dataset')
262278
dataset = None
263279
for images, equations in zip(args.images, args.equations):

pix2tex/eval.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from pix2tex.dataset.dataset import Im2LatexDataset
2-
import os
3-
import sys
42
import argparse
53
import logging
64
import yaml
@@ -90,8 +88,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
9088

9189
if __name__ == '__main__':
9290
parser = argparse.ArgumentParser(description='Test model')
93-
parser.add_argument('--config', default='settings/config.yaml', help='path to yaml config file', type=argparse.FileType('r'))
94-
parser.add_argument('-c', '--checkpoint', default='checkpoints/weights.pth', type=str, help='path to model checkpoint')
91+
parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
92+
parser.add_argument('-c', '--checkpoint', default=None, type=str, help='path to model checkpoint')
9593
parser.add_argument('-d', '--data', default='dataset/data/val.pkl', type=str, help='Path to Dataset pkl file')
9694
parser.add_argument('--no-cuda', action='store_true', help='Use CPU')
9795
parser.add_argument('-b', '--batchsize', type=int, default=10, help='Batch size')
@@ -100,7 +98,10 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
10098
parser.add_argument('-n', '--num-batches', type=int, default=None, help='how many batches to evaluate on. Defaults to None (all)')
10199

102100
parsed_args = parser.parse_args()
103-
with parsed_args.config as f:
101+
if parsed_args.config is None:
102+
with in_model_path():
103+
parsed_args.config = os.path.realpath('settings/config.yaml')
104+
with open(parsed_args.config, 'r') as f:
104105
params = yaml.load(f, Loader=yaml.FullLoader)
105106
args = parse_args(Munch(params))
106107
args.testbatchsize = parsed_args.batchsize
@@ -109,8 +110,10 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
109110
logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
110111
seed_everything(args.seed if 'seed' in args else 42)
111112
model = get_model(args)
112-
if parsed_args.checkpoint is not None:
113-
model.load_state_dict(torch.load(parsed_args.checkpoint, args.device))
113+
if parsed_args.checkpoint is None:
114+
with in_model_path():
115+
parsed_args.checkpoint = os.path.realpath('checkpoints/weights.pth')
116+
model.load_state_dict(torch.load(parsed_args.checkpoint, args.device))
114117
dataset = Im2LatexDataset().load(parsed_args.data)
115118
valargs = args.copy()
116119
valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)

pix2tex/train.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
from pix2tex.dataset.dataset import Im2LatexDataset
22
import os
3-
import sys
43
import argparse
54
import logging
65
import yaml
76

8-
import numpy as np
97
import torch
10-
import torch.optim as optim
11-
import torch.nn as nn
128
from munch import Munch
139
from tqdm.auto import tqdm
1410
import wandb
@@ -72,14 +68,15 @@ def save_models(e):
7268

7369
if __name__ == '__main__':
7470
parser = argparse.ArgumentParser(description='Train model')
75-
parser.add_argument('--config', default='settings/debug.yaml', help='path to yaml config file', type=argparse.FileType('r'))
76-
parser.add_argument('-d', '--data', default='dataset/data/train.pkl', type=str, help='Path to Dataset pkl file')
71+
parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
7772
parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
7873
parser.add_argument('--debug', action='store_true', help='DEBUG')
7974
parser.add_argument('--resume', help='path to checkpoint folder', action='store_true')
80-
8175
parsed_args = parser.parse_args()
82-
with parsed_args.config as f:
76+
if parsed_args.config is None:
77+
with in_model_path():
78+
parsed_args.config = os.path.realpath('settings/debug.yaml')
79+
with open(parsed_args.config, 'r') as f:
8380
params = yaml.load(f, Loader=yaml.FullLoader)
8481
args = parse_args(Munch(params), **vars(parsed_args))
8582
logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)

pix2tex/train_resizer.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,18 @@ def train_epoch(sched=None):
135135

136136
if __name__ == '__main__':
137137
parser = argparse.ArgumentParser(description='Train size classification model')
138-
parser.add_argument('--config', default='settings/debug.yaml', help='path to yaml config file', type=argparse.FileType('r'))
138+
parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
139139
parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
140140
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
141141
parser.add_argument('--resume', help='path to checkpoint folder', type=str, default='')
142142
parser.add_argument('--out', type=str, default='checkpoints/image_resizer.pth', help='output destination for trained model')
143143
parser.add_argument('--num_epochs', type=int, default=10, help='number of epochs to train')
144144
parser.add_argument('--batchsize', type=int, default=10)
145145
parsed_args = parser.parse_args()
146-
with parsed_args.config as f:
146+
if parsed_args.config is None:
147+
with in_model_path():
148+
parsed_args.config = os.path.realpath('settings/debug.yaml')
149+
with open(parsed_args.config, 'r') as f:
147150
params = yaml.load(f, Loader=yaml.FullLoader)
148151
args = parse_args(Munch(params), **vars(parsed_args))
149152
args.update(**vars(parsed_args))

pix2tex/utils/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def num_model_params(model):
159159
def in_model_path():
160160
from importlib.resources import path
161161
with path('pix2tex', 'model') as model_path:
162-
os.chdir(model_path)
163162
saved = os.getcwd()
163+
os.chdir(model_path)
164164
try:
165165
yield
166166
finally:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setuptools.setup(
1111
name='pix2tex',
12-
version='0.0.6',
12+
version='0.0.8',
1313
description="pix2tex: Using a ViT to convert images of equations into LaTeX code.",
1414
long_description=long_description,
1515
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)