Skip to content

Commit

Permalink
min dim in eval and snip shortcut
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed Jun 6, 2021
1 parent b7bb485 commit 4a61993
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
10 changes: 7 additions & 3 deletions gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QObject, Qt, pyqtSlot, pyqtSignal, QThread
from PyQt5.QtWebEngineWidgets import QWebEngineView
from PyQt5.QtWidgets import QMainWindow, QApplication, QMessageBox, QVBoxLayout, QWidget,\
from PyQt5.QtGui import QKeySequence
from PyQt5.QtWidgets import QMainWindow, QApplication, QMessageBox, QVBoxLayout, QWidget, QShortcut,\
QPushButton, QTextEdit, QLineEdit, QFormLayout, QHBoxLayout, QCheckBox, QSpinBox, QDoubleSpinBox
from resources import resources
from pynput.mouse import Controller
Expand Down Expand Up @@ -59,9 +60,12 @@ def initUI(self):
self.tempField.setSingleStep(0.1)

# Create snip button
self.snipButton = QPushButton('Snip', self)
self.snipButton = QPushButton('Snip [Alt+S]', self)
self.snipButton.clicked.connect(self.onClick)

self.shortcut = QShortcut(QKeySequence("Alt+S"), self)
self.shortcut.activated.connect(self.onClick)

# Create retry button
self.retryButton = QPushButton('Retry', self)
self.retryButton.setEnabled(False)
Expand Down Expand Up @@ -165,7 +169,7 @@ def run(self):
try:
prediction = pix2tex.call_model(self.args, *self.objs, img=self.img)
# replace <, > with \lt, \gt so it won't be interpreted as html code
prediction = prediction.replace('<','\\lt ').replace('>','\\gt ')
prediction = prediction.replace('<', '\\lt ').replace('>', '\\gt ')
self.finished.emit({"success": True, "prediction": prediction})
except Exception as e:
print(e)
Expand Down
20 changes: 13 additions & 7 deletions pix2tex.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@
last_pic = None


def minmax_size(img, max_dimensions):
ratios = [a/b for a, b in zip(img.size, max_dimensions)]
if any([r > 1 for r in ratios]):
size = np.array(img.size)//max(ratios)
img = img.resize(size.astype(int), Image.BILINEAR)
def minmax_size(img, max_dimensions=None, min_dimensions=None):
if max_dimensions is not None:
ratios = [a/b for a, b in zip(img.size, max_dimensions)]
if any([r > 1 for r in ratios]):
size = np.array(img.size)//max(ratios)
img = img.resize(size.astype(int), Image.BILINEAR)
if min_dimensions is not None:
if any([s < min_dimensions[i] for i, s in enumerate(img.size)]):
padded_im = Image.new('L', min_dimensions, 255)
padded_im.paste(img, img.getbbox())
img = padded_im
return img


Expand Down Expand Up @@ -72,13 +78,13 @@ def call_model(args, model, image_resizer, tokenizer, img=None):
img = last_pic.copy()
else:
last_pic = img.copy()
img = minmax_size(pad(img), args.max_dimensions)
img = minmax_size(pad(img), args.max_dimensions, args.min_dimensions)
if image_resizer is not None and not args.no_resize:
with torch.no_grad():
input_image = pad(img).convert('RGB').copy()
r, w = 1, img.size[0]
for i in range(10):
img = minmax_size(input_image.resize((w, int(input_image.size[1]*r)), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions)
img = minmax_size(input_image.resize((w, int(input_image.size[1]*r)), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions)
t = test_transform(image=np.array(pad(img).convert('RGB')))['image'][:1].unsqueeze(0)
w = image_resizer(t.to(args.device)).argmax(-1).item()*32
if (w/img.size[0] == 1):
Expand Down

0 comments on commit 4a61993

Please sign in to comment.