Skip to content

Commit 279cb2f

Browse files
committed
Merge branch 'main' into api
2 parents f002a65 + 59903b1 commit 279cb2f

File tree

7 files changed

+106
-81
lines changed

7 files changed

+106
-81
lines changed

README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ To run the model you need Python 3.7+
1212
Install the package `pix2tex`:
1313

1414
```
15-
pip install pix2tex
15+
pip install pix2tex[gui]
1616
```
1717

1818
Model checkpoints will be downloaded automatically.
@@ -73,7 +73,6 @@ In order to render the math in many different fonts we use XeLaTeX, generate a
7373
* [XeLaTeX](https://www.ctan.org/pkg/xetex)
7474
* [ImageMagick](https://imagemagick.org/) with [Ghostscript](https://www.ghostscript.com/index.html). (for converting pdf to png)
7575
* [Node.js](https://nodejs.org/) to run [KaTeX](https://github.com/KaTeX/KaTeX) (for normalizing Latex code)
76-
* [`de-macro`](https://www.ctan.org/pkg/de-macro) >= 1.4 (only for parsing arxiv papers)
7776
* Python 3.7+ & dependencies (specified in `setup.py`)
7877

7978
### Fonts

pix2tex/dataset/arxiv.py

+17-24
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,16 @@
88
import sys
99
import argparse
1010
import logging
11-
import shutil
12-
import subprocess
1311
import tarfile
1412
import tempfile
1513
import chardet
1614
import logging
1715
import requests
1816
import urllib.request
17+
from tqdm import tqdm
1918
from urllib.error import HTTPError
20-
from pix2tex.dataset.extract_latex import *
21-
from pix2tex.dataset.scraping import *
19+
from pix2tex.dataset.extract_latex import find_math
20+
from pix2tex.dataset.scraping import recursive_search
2221
from pix2tex.dataset.demacro import *
2322

2423
# logging.getLogger().setLevel(logging.INFO)
@@ -49,7 +48,7 @@ def download(url, dir_path='./'):
4948
return 0
5049

5150

52-
def read_tex_files(file_path, demacro=True):
51+
def read_tex_files(file_path):
5352
tex = ''
5453
try:
5554
with tempfile.TemporaryDirectory() as tempdir:
@@ -58,18 +57,11 @@ def read_tex_files(file_path, demacro=True):
5857
tf.extractall(tempdir)
5958
tf.close()
6059
texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
61-
# de-macro
62-
if demacro:
63-
ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
64-
if ret.returncode == 0:
65-
texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
6660
except tarfile.ReadError as e:
6761
texfiles = [file_path] # [os.path.join(tempdir, file_path+'.tex')]
68-
#shutil.move(file_path, texfiles[0])
69-
7062
for texfile in texfiles:
7163
try:
72-
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
64+
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
7365
except UnicodeDecodeError:
7466
pass
7567
tex = unfold(convert(tex))
@@ -85,32 +77,32 @@ def download_paper(arxiv_id, dir_path='./'):
8577
return download(url, dir_path)
8678

8779

88-
def read_paper(targz_path, delete=True, demacro=True):
80+
def read_paper(targz_path, delete=True):
8981
paper = ''
9082
if targz_path != 0:
91-
paper = read_tex_files(targz_path, demacro)
83+
paper = read_tex_files(targz_path)
9284
if delete:
9385
os.remove(targz_path)
9486
return paper
9587

9688

97-
def parse_arxiv(id, demacro=True):
89+
def parse_arxiv(id):
9890
tempdir = tempfile.gettempdir()
99-
text = read_paper(download_paper(id, tempdir), demacro=demacro)
91+
text = read_paper(download_paper(id, tempdir))
10092
#print(text, file=open('paper.tex', 'w'))
10193
#linked = list(set([l for l in re.findall(arxiv_id, text)]))
10294

10395
return find_math(text, wiki=False), []
10496

10597

10698
if __name__ == '__main__':
99+
# logging.getLogger().setLevel(logging.DEBUG)
107100
parser = argparse.ArgumentParser(description='Extract math from arxiv')
108-
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'id', 'dir'],
101+
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'ids', 'dir'],
109102
help='Where to extract code from. top100: current 100 arxiv papers, id: specific arxiv ids. \
110103
Usage: `python arxiv.py -m id id001 id002`, dir: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
111-
parser.add_argument(nargs='+', dest='args', default=[])
104+
parser.add_argument(nargs='*', dest='args', default=[])
112105
parser.add_argument('-o', '--out', default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data'), help='output directory')
113-
parser.add_argument('-d', '--no-demacro', dest='demacro', action='store_false', help='Use de-macro (Slows down extraction but improves quality)')
114106
args = parser.parse_args()
115107
if '.' in args.out:
116108
args.out = os.path.dirname(args.out)
@@ -122,7 +114,7 @@ def parse_arxiv(id, demacro=True):
122114
if args.mode == 'ids':
123115
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper')
124116
elif args.mode == 'top100':
125-
url = 'https://arxiv.org/list/hep-th/2012?skip=0&show=100' # https://arxiv.org/list/hep-th/2012?skip=0&show=100
117+
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=100' #'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
126118
ids = get_all_arxiv_ids(requests.get(url).text)
127119
math, visited = [], ids
128120
for id in tqdm(ids):
@@ -133,15 +125,16 @@ def parse_arxiv(id, demacro=True):
133125
math, visited = [], []
134126
for f in tqdm(dirs):
135127
try:
136-
text = read_paper(os.path.join(args.args[0], f), False, args.demacro)
128+
text = read_paper(os.path.join(args.args[0], f), False)
137129
math.extend(find_math(text, wiki=False))
138-
visited.append(os.path.basename(f))
130+
visited.append(os.path.basename(f))
139131
except Exception as e:
140132
logging.debug(e)
141133
pass
142134
else:
143135
raise NotImplementedError
144-
136+
print('\n'.join(math))
137+
sys.exit(0)
145138
for l, name in zip([visited, math], ['visited_arxiv.txt', 'math_arxiv.txt']):
146139
f = os.path.join(args.out, name)
147140
if not os.path.exists(f):

pix2tex/dataset/demacro.py

+51-16
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@
22

33
import argparse
44
import re
5+
from pix2tex.dataset.extract_latex import remove_labels
56

67

78
def main():
89
args = parse_command_line()
910
data = read(args.input)
1011
data = convert(data)
11-
if args.demacro:
12-
data = unfold(data)
13-
write(args.output, data)
12+
data = unfold(data)
13+
if args.output is not None:
14+
write(args.output, data)
15+
else:
16+
print(data)
1417

1518

1619
def parse_command_line():
1720
parser = argparse.ArgumentParser(description='Replace \\def with \\newcommand where possible.')
1821
parser.add_argument('input', help='TeX input file with \\def')
19-
parser.add_argument('--output', '-o', required=True, help='TeX output file with \\newcommand')
20-
parser.add_argument('--demacro', action='store_true', help='replace all commands with their definition')
21-
22+
parser.add_argument('--output', '-o', default=None, help='TeX output file with \\newcommand')
2223
return parser.parse_args()
2324

2425

@@ -37,27 +38,61 @@ def convert(data):
3738
)
3839

3940

40-
def unfold(t):
41-
cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}\n', t)
42-
cmds = sorted(cmds, key=lambda x: len(x[0]))
43-
# print(cmds)
41+
def bracket_replace(string: str) -> str:
    '''
    Mask nested curly brackets with placeholder symbols.

    Outermost `{` and `}` are left untouched. Every bracket that sits inside
    another open bracket is replaced by `Ḋ` (opening) or `Ḍ` (closing), so that
    a non-greedy `{(.+?)}` pattern applied afterwards captures a complete
    top-level group. Use `undo_bracket_replace` to restore the brackets.

    Args:
        string: text that may contain (possibly unbalanced) curly brackets.

    Returns:
        A string of the same length with the nested brackets masked.
    '''
    depth = 0
    out = list(string)
    for i, c in enumerate(out):
        if c == '{':
            if depth > 0:
                out[i] = 'Ḋ'
            depth += 1
        elif c == '}':
            # Clamp at zero: a stray closing bracket must not push the depth
            # negative, otherwise every following group would be masked one
            # level too shallow.
            depth = max(depth - 1, 0)
            if depth > 0:
                out[i] = 'Ḍ'
    return ''.join(out)


def undo_bracket_replace(string):
    '''
    Restore the curly brackets masked by `bracket_replace`.

    Args:
        string: text containing the `Ḋ`/`Ḍ` placeholders.

    Returns:
        The text with `Ḋ` turned back into `{` and `Ḍ` into `}`.
    '''
    return string.replace('Ḋ', '{').replace('Ḍ', '}')
61+
62+
63+
def sweep(t, cmds):
64+
num_matches = 0
4465
for c in cmds:
4566
nargs = int(c[1][1]) if c[1] != r'' else 0
46-
# print(c)
67+
optional = c[2] != r''
4768
if nargs == 0:
48-
#t = t.replace(r'\\%s' % c[0], c[-1])
4969
t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
5070
else:
51-
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if c[2] != r'' else 0))+r')', t)
52-
# print(matches)
71+
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if optional else 0))+r')', t)
72+
num_matches += len(matches)
5373
for i, m in enumerate(matches):
5474
r = c[-1]
5575
if m[1] == r'':
5676
matches[i] = (m[0], c[2][1:-1], *m[2:])
5777
for j in range(1, nargs+1):
58-
r = r.replace(r'#%i' % j, matches[i][j])
78+
r = r.replace(r'#%i' % j, matches[i][j+int(not optional)])
5979
t = t.replace(matches[i][0], r)
60-
return t
80+
return t, num_matches
81+
82+
83+
def unfold(t, max_depth=10):
    '''
    Expand user-defined LaTeX macros at their points of use.

    Labels are stripped with `remove_labels`, the newcommand/renewcommand
    definitions are collected, and `sweep` is applied repeatedly so that
    macros nested inside other macros get expanded as well.

    Args:
        t: LaTeX source text.
        max_depth: upper bound on the number of expansion passes, i.e. how
            deeply nested macros are followed. Defaults to 10.

    Returns:
        The text with labels removed and macros expanded.
    '''
    # Newlines are swapped for a placeholder while the patterns run
    # (presumably so that `.` can match across line breaks).
    t = remove_labels(t).replace('\n', 'Ċ')

    cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}Ċ', t)
    if not cmds:
        # No macro definitions found: nothing to expand.
        return t.replace('Ċ', '\n')
    cmds = sorted(cmds, key=lambda x: len(x[0]))
    for _ in range(max_depth):
        # Nested brackets are masked during the sweep so that each macro
        # argument is matched as one complete top-level group.
        t = bracket_replace(t)
        t, N = sweep(t, cmds)
        t = undo_bracket_replace(t)
        if N == 0:
            # The sweep reported no further matches.
            break
    return t.replace('Ċ', '\n')
6196

6297

6398
def replace(match):

pix2tex/dataset/extract_latex.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
displaymath = re.compile(r'(\\displaystyle)(.{%i,%i}?)(\}(?:<|"))' % (1, MAX_CHARS))
1111
outer_whitespace = re.compile(
1212
r'^\\,|\\,$|^~|~$|^\\ |\\ $|^\\thinspace|\\thinspace$|^\\!|\\!$|^\\:|\\:$|^\\;|\\;$|^\\enspace|\\enspace$|^\\quad|\\quad$|^\\qquad|\\qquad$|^\\hspace{[a-zA-Z0-9]+}|\\hspace{[a-zA-Z0-9]+}$|^\\hfill|\\hfill$')
13-
13+
label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'caption', 'eqref']]
1414

1515
def check_brackets(s):
1616
a = []
@@ -39,17 +39,18 @@ def check_brackets(s):
3939
else:
4040
return s
4141

42+
def remove_labels(string):
    '''
    Delete every command matched by one of the `label_names` patterns.

    Args:
        string: LaTeX source text.

    Returns:
        The text with all matches of the `label_names` patterns removed.
    '''
    cleaned = string
    for pattern in label_names:
        cleaned = pattern.sub('', cleaned)
    return cleaned
4246

4347
def clean_matches(matches, min_chars=MIN_CHARS):
44-
template = r'\\%s\s?\{(.*?)\}'
45-
sub = [re.compile(template % s) for s in ['ref', 'cite', 'label', 'caption']]
4648
faulty = []
4749
for i in range(len(matches)):
4850
if 'tikz' in matches[i]: # do not support tikz at the moment
4951
faulty.append(i)
5052
continue
51-
for s in sub:
52-
matches[i] = re.sub(s, '', matches[i])
53+
matches[i] = remove_labels(matches[i])
5354
matches[i] = matches[i].replace('\n', '').replace(r'\notag', '').replace(r'\nonumber', '')
5455
matches[i] = re.sub(outer_whitespace, '', matches[i])
5556
if len(matches[i]) < min_chars:

pix2tex/dataset/scraping.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import html
66
import requests
77
import re
8-
import tempfile
9-
from pix2tex.dataset.arxiv import *
10-
from pix2tex.dataset.extract_latex import *
8+
from pix2tex.dataset.extract_latex import find_math
119

1210
wikilinks = re.compile(r'href="/wiki/(.*?)"')
1311
htmltags = re.compile(r'<(noscript|script)>.*?<\/\1>', re.S)

pix2tex/utils/utils.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ def num_model_params(model):
157157

158158
@contextlib.contextmanager
159159
def in_model_path():
160-
from importlib.resources import path
161-
with path('pix2tex', 'model') as model_path:
162-
saved = os.getcwd()
163-
os.chdir(model_path)
164-
try:
165-
yield
166-
finally:
167-
os.chdir(saved)
160+
import pix2tex
161+
model_path = os.path.join(os.path.dirname(pix2tex.__file__), 'model')
162+
saved = os.getcwd()
163+
os.chdir(model_path)
164+
try:
165+
yield
166+
finally:
167+
os.chdir(saved)

setup.py

+22-23
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# read the contents of your README file
66
from pathlib import Path
77
this_directory = Path(__file__).parent
8-
long_description = (this_directory / "README.md").read_text()
8+
long_description = (this_directory / 'README.md').read_text()
99

1010
gui = [
1111
"PyQt5",
@@ -21,8 +21,8 @@
2121

2222
setuptools.setup(
2323
name='pix2tex',
24-
version='0.0.12',
25-
description="pix2tex: Using a ViT to convert images of equations into LaTeX code.",
24+
version='0.0.14',
25+
description='pix2tex: Using a ViT to convert images of equations into LaTeX code.',
2626
long_description=long_description,
2727
long_description_content_type='text/markdown',
2828
author='Lukas Blecher',
@@ -43,26 +43,25 @@
4343
]
4444
},
4545
install_requires=[
46-
"tqdm>=4.47.0",
47-
"munch>=2.5.0",
48-
"torch>=1.7.1",
49-
"torchvision>=0.8.1",
50-
"opencv_python_headless>=4.1.1.26",
51-
"requests>=2.22.0",
52-
"einops>=0.3.0",
53-
"chardet>=3.0.4",
54-
"x_transformers==0.15.0",
55-
"imagesize>=1.2.0",
56-
"transformers==4.2.2",
57-
"tokenizers==0.9.4",
58-
"numpy>=1.19.5",
59-
"Pillow>=9.1.0",
60-
"PyYAML>=5.4.1",
61-
"torchtext>=0.6.0",
62-
"albumentations>=0.5.2",
63-
"pandas>=1.0.0",
64-
"timm",
65-
"python-Levenshtein>=0.12.2",
46+
'tqdm>=4.47.0',
47+
'munch>=2.5.0',
48+
'torch>=1.7.1',
49+
'opencv_python_headless>=4.1.1.26',
50+
'requests>=2.22.0',
51+
'einops>=0.3.0',
52+
'x_transformers==0.15.0',
53+
'transformers>=4.18.0',
54+
'tokenizers==0.12.1',
55+
'numpy>=1.19.5',
56+
'Pillow>=9.1.0',
57+
'PyYAML>=5.4.1',
58+
'pandas>=1.0.0',
59+
'timm',
60+
'chardet>=3.0.4',
61+
'python-Levenshtein>=0.12.2',
62+
'torchtext>=0.6.0',
63+
'albumentations>=0.5.2',
64+
'imagesize>=1.2.0',
6665
],
6766
extras_require={
6867
"all": gui+api,

0 commit comments

Comments
 (0)