Commit 97df469

dataset updates
1 parent 69da9f4 commit 97df469

13 files changed: +302 -66

README.md (+1 -1)

@@ -11,7 +11,7 @@ The goal of this project is to create a learning based system that takes an imag
 The `pix2tex.py` file offers a fast way to get the model prediction of an image. First you need to copy the formula image into the clipboard memory for example by using a snipping tool (on Windows built in `Win`+`Shift`+`S`). Next just call the script with `python pix2tex.py`. It will print out the predicted LaTeX code for that image and also copy it into your clipboard.
 
 ## Data
-We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. [wikipedia](www.wikipedia.org), [arXiv](www.arxiv.org). We also use the formulae from the [im2latex-170k](https://www.kaggle.com/rvente/im2latex170k) dataset.
+We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. [wikipedia](www.wikipedia.org), [arXiv](www.arxiv.org). We also use the formulae from the [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) dataset.
 
 ### Fonts
 Latin Modern Math, GFSNeohellenicMath.otf, Asana Math, XITS Math, Cambria Math

dataset/arxiv.py (+66 -30)

@@ -6,6 +6,9 @@
 import glob
 import re
 import sys
+import argparse
+import logging
+import shutil
 import subprocess
 import tarfile
 import tempfile
@@ -17,9 +20,11 @@
 try:
     from extract_latex import *
     from scraping import *
+    from demacro import *
 except:
     from dataset.extract_latex import *
     from dataset.scraping import *
+    from dataset.demacro import *
 
 # logging.getLogger().setLevel(logging.INFO)
 arxiv_id = re.compile(r'(?<!\d)(\d{4}\.\d{5})(?!\d)')
@@ -49,73 +54,104 @@ def download(url, dir_path='./'):
     return 0
 
 
-def read_tex_files(file_path):
+def read_tex_files(file_path, demacro=True):
     tex = ''
     try:
         with tempfile.TemporaryDirectory() as tempdir:
-            tf = tarfile.open(file_path, 'r')
-            tf.extractall(tempdir)
-            tf.close()
-            texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
-            # de-macro
-            ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
-            if ret.returncode == 0:
-                texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
+            try:
+                tf = tarfile.open(file_path, 'r')
+                tf.extractall(tempdir)
+                tf.close()
+                texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
+                # de-macro
+                if demacro:
+                    ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
+                    if ret.returncode == 0:
+                        texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
+            except tarfile.ReadError as e:
+                texfiles = [file_path]  # [os.path.join(tempdir, file_path+'.tex')]
+                #shutil.move(file_path, texfiles[0])
+
             for texfile in texfiles:
                 try:
-                    tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
+                    tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
                 except UnicodeDecodeError:
                     pass
-
-    except tarfile.ReadError:
-        try:
-            tex += open(file_path, 'r', encoding=chardet.detect(open(file_path, 'br').readline())['encoding']).read()
-        except Exception as e:
-            logging.info('Could not read %s: %s' % (file_path, str(e)))
-            pass
+            tex = unfold(convert(tex))
+    except Exception as e:
+        logging.debug('Could not read %s: %s' % (file_path, str(e)))
+        pass
     # remove comments
     return re.sub(r'(?<!\\)%.*\n', '', tex)
 
 
-def read_paper(arxiv_id, dir_path='./'):
+def download_paper(arxiv_id, dir_path='./'):
     url = arxiv_base + arxiv_id
-    targz_path = download(url, dir_path)
+    return download(url, dir_path)
+
+
+def read_paper(targz_path, delete=True, demacro=True):
     paper = ''
     if targz_path != 0:
-        paper = read_tex_files(targz_path)
-        os.remove(targz_path)
+        paper = read_tex_files(targz_path, demacro)
+        if delete:
+            os.remove(targz_path)
     return paper
 
 
-def parse_arxiv(id):
+def parse_arxiv(id, demacro=True):
     tempdir = tempfile.gettempdir()
-    text = read_paper(id, tempdir)
+    text = read_paper(download_paper(id, tempdir), demacro=demacro)
     #print(text, file=open('paper.tex', 'w'))
     #linked = list(set([l for l in re.findall(arxiv_id, text)]))
 
     return find_math(text, wiki=False), []
 
 
 if __name__ == '__main__':
-    skips = os.path.join(sys.path[0], 'dataset', 'data', 'visited_arxiv.txt')
+    parser = argparse.ArgumentParser(description='Extract math from arxiv')
+    parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'id', 'dir'],
+                        help='Where to extract code from. top100: current 100 arxiv papers, id: specific arxiv ids. \
+                              Usage: `python arxiv.py -m id id001 id002`, dir: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
+    parser.add_argument(nargs='+', dest='args', default=[])
+    parser.add_argument('-o', '--out', default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data'), help='output directory')
+    parser.add_argument('-d', '--no-demacro', dest='demacro', action='store_false', help='Use de-macro (Slows down extraction but improves quality)')
+    args = parser.parse_args()
+    if '.' in args.out:
+        args.out = os.path.dirname(args.out)
+    skips = os.path.join(args.out, 'visited_arxiv.txt')
     if os.path.exists(skips):
         skip = open(skips, 'r', encoding='utf-8').read().split('\n')
     else:
         skip = []
-    if len(sys.argv) > 1:
-        arxiv_ids = sys.argv[1:]
-        visited, math = recursive_search(parse_arxiv, arxiv_ids, skip=skip, unit='paper')
-
-    else:
+    if args.mode == 'ids':
+        visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper')
+    elif args.mode == 'top100':
        url = 'https://arxiv.org/list/hep-th/2012?skip=0&show=100'  # https://arxiv.org/list/hep-th/2012?skip=0&show=100
        ids = get_all_arxiv_ids(requests.get(url).text)
        math, visited = [], ids
        for id in tqdm(ids):
            m, _ = parse_arxiv(id)
            math.extend(m)
+    elif args.mode == 'dir':
+        dirs = os.listdir(args.args[0])
+        math, visited = [], []
+        for f in tqdm(dirs):
+            try:
+                text = read_paper(os.path.join(args.args[0], f), False, args.demacro)
+                math.extend(find_math(text, wiki=False))
+                visited.append(os.path.basename(f))
+            except Exception as e:
+                logging.debug(e)
+                pass
+    else:
+        raise NotImplementedError
 
     for l, name in zip([visited, math], ['visited_arxiv.txt', 'math_arxiv.txt']):
-        f = open(os.path.join(sys.path[0], 'dataset', 'data', name), 'a', encoding='utf-8')
+        f = os.path.join(args.out, name)
+        if not os.path.exists(f):
+            open(f, 'w').write('')
+        f = open(f, 'a', encoding='utf-8')
         for element in l:
             f.write(element)
             f.write('\n')
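
Aside: the refactor above splits the old `read_paper` into `download_paper` (fetch the .tar.gz) and `read_paper` (extract, optionally de-macro, strip comments), which is what the new `dir` mode builds on: local archives can now be parsed without re-downloading or deleting them. One thing to watch: the dispatch compares `args.mode` against `'ids'` while the argparse choice is spelled `id`, so as committed that mode falls through to the `NotImplementedError` branch. A minimal sketch of how the two entry points compose, assuming the `dataset` package is importable; the arXiv id and paths below are made-up placeholders:

    # Sketch only: '2012.00001' and the paths are placeholder values.
    import tempfile
    from dataset.arxiv import download_paper, read_paper
    from dataset.extract_latex import find_math

    # Online: fetch the source archive, parse it, delete it afterwards.
    targz = download_paper('2012.00001', tempfile.gettempdir())
    tex = read_paper(targz, delete=True, demacro=True)

    # Offline (what `-m dir` does per file): parse a local archive and keep it.
    tex = read_paper('archives/2012.00001.tar.gz', delete=False, demacro=False)
    print(find_math(tex, wiki=False)[:5])  # first few extracted formulas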

dataset/dataset.py (+59 -7)

@@ -12,7 +12,49 @@
 from collections import defaultdict
 import pickle
 from PIL import Image
+import cv2
 from transformers import PreTrainedTokenizerFast
+from tqdm.auto import tqdm
+import albumentations as alb
+from albumentations.pytorch import ToTensorV2
+
+
+class AugWrap:
+    def __init__(self, aug):
+        self.aug = aug
+
+    def __call__(self, image):
+        return self.aug(image=image)['image'][:1]  # /255
+
+
+train_transform = AugWrap(
+    alb.Compose(
+        [
+            alb.Compose(
+                [alb.ShiftScaleRotate(shift_limit=0, scale_limit=(-.15, 0), rotate_limit=1, border_mode=0, interpolation=3,
+                                      value=[255, 255, 255], p=1),
+                 alb.GridDistortion(distort_limit=0.1, border_mode=0, interpolation=3, value=[255, 255, 255], p=.5)], p=.15),
+            alb.InvertImg(p=.15),
+            alb.RGBShift(r_shift_limit=15, g_shift_limit=15,
+                         b_shift_limit=15, p=0.3),
+            alb.GaussNoise(10, p=.2),
+            alb.RandomBrightnessContrast(.05, (-.2, 0), True, p=0.2),
+            alb.JpegCompression(95, p=.5),
+            alb.ToGray(always_apply=True),
+            alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+            # alb.Sharpen()
+            ToTensorV2(),
+        ]
+    ))
+test_transform = AugWrap(
+    alb.Compose(
+        [
+            alb.ToGray(always_apply=True),
+            alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+            # alb.Sharpen()
+            ToTensorV2(),
+        ]
+    ))
 
 
 class Im2LatexDataset:
@@ -26,8 +68,9 @@ class Im2LatexDataset:
     pad_token_id = 0
     bos_token_id = 1
     eos_token_id = 2
+    transform = train_transform
 
-    def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_dimensions=(1024, 512), keep_smaller_batches=False):
+    def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_dimensions=(1024, 512), pad=False, keep_smaller_batches=False, test=False):
         """Generates a torch dataset from pairs of `equations` and `images`.
 
         Args:
@@ -37,37 +80,41 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
             shuffle (bool, optional): Defaults to True.
             batchsize (int, optional): Defaults to 16.
             max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
+            pad (bool): Pad the images to `max_dimensions`. Defaults to False.
             keep_smaller_batches (bool): Whether to also return batches with smaller size than `batchsize`. Defaults to False.
+            test (bool): Whether to use the test transformation or not. Defaults to False.
         """
 
         if images is not None and equations is not None:
             assert tokenizer is not None
-            self.images = [path.replace('\\', '/') for path in glob.glob(join(images, '*.png'))]
+            self.images = [path.replace('\\', '/') for path in glob.glob(join(images, '*.png'))]
             self.sample_size = len(self.images)
             eqs = open(equations, 'r').read().split('\n')
             self.indices = [int(os.path.basename(img).split('.')[0]) for img in self.images]
             self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer)
             self.shuffle = shuffle
             self.batchsize = batchsize
             self.max_dimensions = max_dimensions
+            self.pad = pad
             self.keep_smaller_batches = keep_smaller_batches
+            self.test = test
             self.data = defaultdict(lambda: [])
             # check the image dimension for every image and group them together
-            for i, im in enumerate(self.images):
+            for i, im in tqdm(enumerate(self.images), total=len(self.images)):
                 width, height = imagesize.get(im)
                 if width <= max_dimensions[0] and height <= max_dimensions[1]:
                     self.data[(width, height)].append((eqs[self.indices[i]], im))
             self.data = dict(self.data)
             self._get_size()
 
-            self.transform = transforms.Compose([transforms.PILToTensor()])  # , transforms.Normalize([200],[255/2]),transforms.RandomPerspective(fill=0)])
             iter(self)
 
     def __len__(self):
         return self.size
 
     def __iter__(self):
         self.i = 0
+        self.transform = test_transform if self.test else train_transform
         self.pairs = []
         for k in self.data:
             info = np.array(self.data[k], dtype=object)
@@ -105,12 +152,17 @@ def prepare_data(self, batch):
         eqs, ims = batch.T
         images = []
         for path in list(ims):
-            images.append(self.transform(Image.open(path)))
+            im = cv2.imread(path)
+            if im is None:
+                print(path, 'not found!')
+                continue
+            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+            images.append(self.transform(im))
         tok = self.tokenizer(list(eqs), return_token_type_ids=False)
         # pad with bos and eos token
         for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
             tok[k] = pad_sequence([torch.LongTensor([p[0]]+x+[p[1]]) for x in tok[k]], batch_first=True, padding_value=self.pad_token_id)
-        images = torch.cat(images).float().unsqueeze(1)/255
+        images = torch.cat(images).float().unsqueeze(1)
         if self.pad:
             h, w = images.shape[2:]
             images = F.pad(images, (0, self.max_dimensions[0]-w, 0, self.max_dimensions[1]-h), value=1)
@@ -142,7 +194,7 @@ def save(self, filename):
             pickle.dump(self, file)
 
     def update(self, **kwargs):
-        for k in ['batchsize', 'shuffle', 'pad', 'keep_smaller_batches']:
+        for k in ['batchsize', 'shuffle', 'pad', 'keep_smaller_batches', 'test']:
             if k in kwargs:
                 setattr(self, k, kwargs[k])
         if 'max_dimensions' in kwargs:
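
Aside: image loading now goes through OpenCV and the albumentations pipelines above (mild scale/rotation jitter, grid distortion, inversion, RGB shift, noise, brightness/contrast, JPEG compression, then grayscale, normalize, tensor), replacing the bare `PILToTensor` transform; since `Normalize` runs inside the pipeline, the old `/255` in `prepare_data` is dropped. A rough sketch of what a single image goes through, assuming `dataset.dataset` is importable; the image path is a placeholder:

    # Sketch only: the image path is a placeholder.
    import cv2
    from dataset.dataset import train_transform, test_transform

    im = cv2.imread('data/images/0000001.png')  # HxWx3 uint8, BGR
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)    # mirror prepare_data's conversion
    x = train_transform(im)                     # AugWrap -> normalized 1xHxW float tensor
    print(x.shape)                              # torch.Size([1, H, W])

    x_eval = test_transform(im)                 # grayscale + normalize only, no augmentation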

dataset/demacro.py (+96)

@@ -0,0 +1,96 @@
+# modified from https://tex.stackexchange.com/a/521639
+
+import argparse
+import re
+
+
+def main():
+    args = parse_command_line()
+    data = read(args.input)
+    data = convert(data)
+    if args.demacro:
+        data = unfold(data)
+    write(args.output, data)
+
+
+def parse_command_line():
+    parser = argparse.ArgumentParser(description='Replace \\def with \\newcommand where possible.')
+    parser.add_argument('input', help='TeX input file with \\def')
+    parser.add_argument('--output', '-o', required=True, help='TeX output file with \\newcommand')
+    parser.add_argument('--demacro', action='store_true', help='replace all commands with their definition')
+
+    return parser.parse_args()
+
+
+def read(path):
+    with open(path, mode='r') as handle:
+        return handle.read()
+
+
+def convert(data):
+    return re.sub(
+        r'((?:\\(?:expandafter|global|long|outer|protected)'
+        r'(?: +|\r?\n *)?)*)?'
+        r'\\def *(\\[a-zA-Z]+) *(?:#+([0-9]))*\{',
+        replace,
+        data,
+    )
+
+
+def unfold(t):
+    cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}\n', t)
+    cmds = sorted(cmds, key=lambda x: len(x[0]))
+    # print(cmds)
+    for c in cmds:
+        nargs = int(c[1][1]) if c[1] != r'' else 0
+        # print(c)
+        if nargs == 0:
+            #t = t.replace(r'\\%s' % c[0], c[-1])
+            t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
+        else:
+            matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if c[2] != r'' else 0))+r')', t)
+            # print(matches)
+            for i, m in enumerate(matches):
+                r = c[-1]
+                if m[1] == r'':
+                    matches[i] = (m[0], c[2][1:-1], *m[2:])
+                for j in range(1, nargs+1):
+                    r = r.replace(r'#%i' % j, matches[i][j])
+                t = t.replace(matches[i][0], r)
+    return t
+
+
+def replace(match):
+    prefix = match.group(1)
+    if (
+        prefix is not None and
+        (
+            'expandafter' in prefix or
+            'global' in prefix or
+            'outer' in prefix or
+            'protected' in prefix
+        )
+    ):
+        return match.group(0)
+
+    result = r'\newcommand'
+    if prefix is None or 'long' not in prefix:
+        result += '*'
+
+    result += '{' + match.group(2) + '}'
+    if match.lastindex == 3:
+        result += '[' + match.group(3) + ']'
+
+    result += '{'
+    return result
+
+
+def write(path, data):
+    with open(path, mode='w') as handle:
+        handle.write(data)
+
+    print('=> File written: {0}'.format(path))
+
+
+if __name__ == '__main__':
+    main()
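
Aside: `read_tex_files` now calls `unfold(convert(tex))` from this module: `convert` rewrites `\def` definitions as `\newcommand` where possible, and `unfold` substitutes each `\newcommand` body back into the text, so the downstream formula extraction sees macro-free LaTeX. A toy round trip, with a made-up input snippet (note `unfold` also rewrites the definition line itself, which is harmless here since only math environments are extracted later):

    # Sketch only: the TeX snippet is made up.
    from dataset.demacro import convert, unfold

    src = '\\def\\R{\\mathbb{R}}\n$x \\in \\R^n$\n'
    converted = convert(src)          # \def\R{...} becomes \newcommand*{\R}{...}
    expanded = unfold(converted)      # occurrences of \R are expanded to \mathbb{R}
    print(expanded.splitlines()[-1])  # $x \in \mathbb{R}^n$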
