Skip to content

Commit 7acd5eb

Browse files
committed
improve latex extraction from arxiv sources
- fix bugs in demacro and cover more cases - add option to save arxiv sources lukas-blecher#133
1 parent 78fe5b2 commit 7acd5eb

File tree

6 files changed

+160
-82
lines changed

6 files changed

+160
-82
lines changed

pix2tex/dataset/__init__.py

-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +0,0 @@
1-
import pix2tex.dataset.arxiv
2-
import pix2tex.dataset.extract_latex
3-
import pix2tex.dataset.latex2png
4-
import pix2tex.dataset.render
5-
import pix2tex.dataset.scraping
6-
import pix2tex.dataset.dataset

pix2tex/dataset/arxiv.py

+69-44
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# modified from https://github.com/soskek/arxiv_leaks
22

33
import argparse
4-
import json
4+
import subprocess
55
import os
66
import glob
77
import re
@@ -10,7 +10,6 @@
1010
import logging
1111
import tarfile
1212
import tempfile
13-
import chardet
1413
import logging
1514
import requests
1615
import urllib.request
@@ -22,7 +21,7 @@
2221

2322
# logging.getLogger().setLevel(logging.INFO)
2423
arxiv_id = re.compile(r'(?<!\d)(\d{4}\.\d{5})(?!\d)')
25-
arxiv_base = 'https://arxiv.org/e-print/'
24+
arxiv_base = 'https://export.arxiv.org/e-print/'
2625

2726

2827
def get_all_arxiv_ids(text):
@@ -48,7 +47,7 @@ def download(url, dir_path='./'):
4847
return 0
4948

5049

51-
def read_tex_files(file_path):
50+
def read_tex_files(file_path, demacro=False):
5251
tex = ''
5352
try:
5453
with tempfile.TemporaryDirectory() as tempdir:
@@ -59,50 +58,59 @@ def read_tex_files(file_path):
5958
texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
6059
except tarfile.ReadError as e:
6160
texfiles = [file_path] # [os.path.join(tempdir, file_path+'.tex')]
61+
if demacro:
62+
ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
63+
if ret.returncode == 0:
64+
texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
6265
for texfile in texfiles:
6366
try:
64-
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
65-
except UnicodeDecodeError:
67+
ct = open(texfile, 'r', encoding='utf-8').read()
68+
tex += ct
69+
except UnicodeDecodeError as e:
70+
logging.debug(e)
6671
pass
67-
tex = unfold(convert(tex))
6872
except Exception as e:
6973
logging.debug('Could not read %s: %s' % (file_path, str(e)))
70-
pass
71-
# remove comments
72-
return re.sub(r'(?<!\\)%.*\n', '', tex)
74+
raise e
75+
tex = pydemacro(tex)
76+
return tex
7377

7478

7579
def download_paper(arxiv_id, dir_path='./'):
7680
url = arxiv_base + arxiv_id
7781
return download(url, dir_path)
7882

7983

80-
def read_paper(targz_path, delete=True):
84+
def read_paper(targz_path, delete=False, demacro=False):
8185
paper = ''
8286
if targz_path != 0:
83-
paper = read_tex_files(targz_path)
87+
paper = read_tex_files(targz_path, demacro=demacro)
8488
if delete:
8589
os.remove(targz_path)
8690
return paper
8791

8892

89-
def parse_arxiv(id):
90-
tempdir = tempfile.gettempdir()
91-
text = read_paper(download_paper(id, tempdir))
92-
#print(text, file=open('paper.tex', 'w'))
93-
#linked = list(set([l for l in re.findall(arxiv_id, text)]))
93+
def parse_arxiv(id, save=None, demacro=True):
94+
if save is None:
95+
dir = tempfile.gettempdir()
96+
else:
97+
dir = save
98+
text = read_paper(download_paper(id, dir), delete=save is None, demacro=demacro)
9499

95100
return find_math(text, wiki=False), []
96101

97102

98103
if __name__ == '__main__':
99104
# logging.getLogger().setLevel(logging.DEBUG)
100105
parser = argparse.ArgumentParser(description='Extract math from arxiv')
101-
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'ids', 'dir'],
102-
help='Where to extract code from. top100: current 100 arxiv papers, id: specific arxiv ids. \
103-
Usage: `python arxiv.py -m id id001 id002`, dir: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
106+
parser.add_argument('-m', '--mode', default='top100', choices=['top', 'ids', 'dirs'],
107+
help='Where to extract code from. top: current 100 arxiv papers (-m top int for any other number of papers), id: specific arxiv ids. \
108+
Usage: `python arxiv.py -m id id001 id002`, dirs: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
104109
parser.add_argument(nargs='*', dest='args', default=[])
105110
parser.add_argument('-o', '--out', default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data'), help='output directory')
111+
parser.add_argument('-d', '--demacro', dest='demacro', action='store_true',
112+
help='Deprecated - Use de-macro (Slows down extraction, may but improves quality). Install https://www.ctan.org/pkg/de-macro')
113+
parser.add_argument('-s', '--save', default=None, type=str, help='When downloading files from arxiv. Where to save the .tar.gz files. Default: Only temporary')
106114
args = parser.parse_args()
107115
if '.' in args.out:
108116
args.out = os.path.dirname(args.out)
@@ -111,30 +119,47 @@ def parse_arxiv(id):
111119
skip = open(skips, 'r', encoding='utf-8').read().split('\n')
112120
else:
113121
skip = []
114-
if args.mode == 'ids':
115-
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper')
116-
elif args.mode == 'top100':
117-
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=100' #'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
118-
ids = get_all_arxiv_ids(requests.get(url).text)
119-
math, visited = [], ids
120-
for id in tqdm(ids):
121-
m, _ = parse_arxiv(id)
122-
math.extend(m)
123-
elif args.mode == 'dir':
124-
dirs = os.listdir(args.args[0])
125-
math, visited = [], []
126-
for f in tqdm(dirs):
127-
try:
128-
text = read_paper(os.path.join(args.args[0], f), False)
129-
math.extend(find_math(text, wiki=False))
130-
visited.append(os.path.basename(f))
131-
except Exception as e:
132-
logging.debug(e)
133-
pass
134-
else:
135-
raise NotImplementedError
136-
print('\n'.join(math))
137-
sys.exit(0)
122+
if args.save is not None:
123+
os.makedirs(args.save, exist_ok=True)
124+
try:
125+
if args.mode == 'ids':
126+
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper', save=args.save, demacro=args.demacro)
127+
elif args.mode == 'top':
128+
num = 100 if len(args.args) == 0 else int(args.args[0])
129+
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=%i' % num # 'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
130+
ids = get_all_arxiv_ids(requests.get(url).text)
131+
math, visited = [], ids
132+
for id in tqdm(ids):
133+
try:
134+
m, _ = parse_arxiv(id, save=args.save, demacro=args.demacro)
135+
math.extend(m)
136+
except ValueError:
137+
pass
138+
elif args.mode == 'dirs':
139+
files = []
140+
for folder in args.args:
141+
files.extend([os.path.join(folder, p) for p in os.listdir(folder)])
142+
math, visited = [], []
143+
for f in tqdm(files):
144+
try:
145+
text = read_paper(f, delete=False, demacro=args.demacro)
146+
math.extend(find_math(text, wiki=False))
147+
visited.append(os.path.basename(f))
148+
except DemacroError as e:
149+
logging.debug(f + str(e))
150+
pass
151+
except KeyboardInterrupt:
152+
break
153+
except Exception as e:
154+
logging.debug(e)
155+
raise e
156+
else:
157+
raise NotImplementedError
158+
except KeyboardInterrupt:
159+
pass
160+
print('Found %i instances of math latex code' % len(math))
161+
# print('\n'.join(math))
162+
# sys.exit(0)
138163
for l, name in zip([visited, math], ['visited_arxiv.txt', 'math_arxiv.txt']):
139164
f = os.path.join(args.out, name)
140165
if not os.path.exists(f):

pix2tex/dataset/demacro.py

+81-24
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22

33
import argparse
44
import re
5+
import logging
6+
from collections import Counter
7+
import time
58
from pix2tex.dataset.extract_latex import remove_labels
69

710

11+
class DemacroError(Exception):
12+
pass
13+
14+
815
def main():
916
args = parse_command_line()
1017
data = read(args.input)
11-
data = convert(data)
12-
data = unfold(data)
18+
data = pydemacro(data)
1319
if args.output is not None:
1420
write(args.output, data)
1521
else:
@@ -28,16 +34,6 @@ def read(path):
2834
return handle.read()
2935

3036

31-
def convert(data):
32-
return re.sub(
33-
r'((?:\\(?:expandafter|global|long|outer|protected)'
34-
r'(?: +|\r?\n *)?)*)?'
35-
r'\\def *(\\[a-zA-Z]+) *(?:#+([0-9]))*\{',
36-
replace,
37-
data,
38-
)
39-
40-
4137
def bracket_replace(string: str) -> str:
4238
'''
4339
replaces all layered brackets with special symbols
@@ -66,7 +62,9 @@ def sweep(t, cmds):
6662
nargs = int(c[1][1]) if c[1] != r'' else 0
6763
optional = c[2] != r''
6864
if nargs == 0:
69-
t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
65+
num_matches += len(re.findall(r'\\%s([\W_^\dĊ])' % c[0], t))
66+
if num_matches > 0:
67+
t = re.sub(r'\\%s([\W_^\dĊ])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
7068
else:
7169
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if optional else 0))+r')', t)
7270
num_matches += len(matches)
@@ -81,18 +79,68 @@ def sweep(t, cmds):
8179

8280

8381
def unfold(t):
84-
t = remove_labels(t).replace('\n', 'Ċ')
85-
86-
cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}Ċ', t)
82+
#t = queue.get()
83+
t = t.replace('\n', 'Ċ')
84+
t = bracket_replace(t)
85+
commands_pattern = r'\\(?:re)?newcommand\*?{\\(.+?)}[\sĊ]*(\[\d\])?[\sĊ]*(\[.+?\])?[\sĊ]*{(.*?)}\s*(?:Ċ|\\)'
86+
cmds = re.findall(commands_pattern, t)
87+
t = re.sub(r'(?<!\\)'+commands_pattern, 'Ċ', t)
8788
cmds = sorted(cmds, key=lambda x: len(x[0]))
88-
for _ in range(10):
89-
# check for up to 10 nested commands
90-
t = bracket_replace(t)
91-
t, N = sweep(t, cmds)
92-
t = undo_bracket_replace(t)
93-
if N == 0:
94-
break
95-
return t.replace('Ċ', '\n')
89+
cmd_names = Counter([c[0] for c in cmds])
90+
for i in reversed(range(len(cmds))):
91+
if cmd_names[cmds[i][0]] > 1:
92+
# something went wrong here. No multiple definitions allowed
93+
del cmds[i]
94+
elif '\\newcommand' in cmds[i][-1]:
95+
logging.debug("Command recognition pattern didn't work properly. %s" % (undo_bracket_replace(cmds[i][-1])))
96+
del cmds[i]
97+
start = time.time()
98+
try:
99+
for i in range(10):
100+
# check for up to 10 nested commands
101+
if i > 0:
102+
t = bracket_replace(t)
103+
t, N = sweep(t, cmds)
104+
if time.time()-start > 5: # not optimal. more sophisticated methods didnt work or are slow
105+
raise TimeoutError
106+
t = undo_bracket_replace(t)
107+
if N == 0 or i == 9:
108+
#print("Needed %i iterations to demacro" % (i+1))
109+
break
110+
elif N > 4000:
111+
raise ValueError("Too many matches. Processing would take too long.")
112+
except ValueError:
113+
pass
114+
except TimeoutError:
115+
pass
116+
except re.error as e:
117+
raise DemacroError(e)
118+
t = remove_labels(t.replace('Ċ', '\n'))
119+
# queue.put(t)
120+
return t
121+
122+
123+
def pydemacro(t):
124+
return unfold(convert(re.sub('\n+', '\n', re.sub(r'(?<!\\)%.*\n', '\n', t))))
125+
126+
127+
def pydemacro2(t, timeout=15):
128+
q = multiprocessing.Queue(len(t))
129+
text = convert(re.sub('\n+', '\n', re.sub(r'(?<!\\)%.*\n', '\n', t)))
130+
q.put(text)
131+
p = multiprocessing.Process(target=unfold, args=(q,))
132+
p.start()
133+
# print("main")
134+
# return q.get(timeout=10)
135+
# interrupt after fixed time
136+
p.join(timeout)
137+
if p.is_alive():
138+
logging.debug('Timeout: killing demacro process')
139+
p.terminate()
140+
p.join()
141+
return text
142+
else:
143+
return q.get()
96144

97145

98146
def replace(match):
@@ -120,6 +168,15 @@ def replace(match):
120168
return result
121169

122170

171+
def convert(data):
172+
data = re.sub(
173+
r'((?:\\(?:expandafter|global|long|outer|protected)(?:\s+|\r?\n\s*)?)*)?\\def\s*(\\[a-zA-Z]+)\s*(?:#+([0-9]))*\{',
174+
replace,
175+
data,
176+
)
177+
return re.sub(r'\\let\s*(\\[a-zA-Z]+)\s*=?\s*(\\?\w+)*', r'\\newcommand*{\1}{\2}\n', data)
178+
179+
123180
def write(path, data):
124181
with open(path, mode='w') as handle:
125182
handle.write(data)

pix2tex/dataset/extract_latex.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
displaymath = re.compile(r'(\\displaystyle)(.{%i,%i}?)(\}(?:<|"))' % (1, MAX_CHARS))
1111
outer_whitespace = re.compile(
1212
r'^\\,|\\,$|^~|~$|^\\ |\\ $|^\\thinspace|\\thinspace$|^\\!|\\!$|^\\:|\\:$|^\\;|\\;$|^\\enspace|\\enspace$|^\\quad|\\quad$|^\\qquad|\\qquad$|^\\hspace{[a-zA-Z0-9]+}|\\hspace{[a-zA-Z0-9]+}$|^\\hfill|\\hfill$')
13-
label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'caption', 'eqref']]
13+
label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'eqref']]
1414

1515
def check_brackets(s):
1616
a = []

pix2tex/dataset/scraping.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def parse_wiki(url):
2727

2828

2929
# recursive search
30-
def recursive_search(parser, seeds, depth=2, skip=[], unit='links', base_url=None):
30+
def recursive_search(parser, seeds, depth=2, skip=[], unit='links', base_url=None, **kwargs):
3131
visited, links = set(skip), set(seeds)
3232
math = []
3333
try:
@@ -39,9 +39,9 @@ def recursive_search(parser, seeds, depth=2, skip=[], unit='links', base_url=No
3939
if not link in visited:
4040
t_bar.set_description('searching %s' % (link))
4141
if base_url:
42-
m, l = parser(base_url+link)
42+
m, l = parser(base_url+link, **kwargs)
4343
else:
44-
m, l = parser(link)
44+
m, l = parser(link, **kwargs)
4545
# check if we got any math from this wiki page and
4646
# if not terminate the tree
4747
if len(m) > 0:
@@ -72,9 +72,12 @@ def recursive_wiki(seeds, depth=4, skip=[]):
7272
url = [sys.argv[1]]
7373
else:
7474
url = ['https://en.wikipedia.org/wiki/Mathematics', 'https://en.wikipedia.org/wiki/Physics']
75-
visited, math = recursive_wiki(url)
75+
try:
76+
visited, math = recursive_wiki(url)
77+
except KeyboardInterrupt:
78+
pass
7679
for l, name in zip([visited, math], ['visited_wiki.txt', 'math_wiki.txt']):
77-
f = open(os.path.join(sys.path[0], 'dataset', 'data', name), 'a', encoding='utf-8')
80+
f = open(os.path.join(sys.path[0], 'data', name), 'a', encoding='utf-8')
7881
for element in l:
7982
f.write(element)
8083
f.write('\n')

setup.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setuptools.setup(
1111
name='pix2tex',
12-
version='0.0.14',
12+
version='0.0.15',
1313
description='pix2tex: Using a ViT to convert images of equations into LaTeX code.',
1414
long_description=long_description,
1515
long_description_content_type='text/markdown',
@@ -45,7 +45,6 @@
4545
'PyYAML>=5.4.1',
4646
'pandas>=1.0.0',
4747
'timm',
48-
'chardet>=3.0.4',
4948
'python-Levenshtein>=0.12.2',
5049
'torchtext>=0.6.0',
5150
'albumentations>=0.5.2',

0 commit comments

Comments
 (0)