Skip to content

Commit bd66642

Browse files
committed
move data_parallel to Model, use generate in eval
This reverts some parts of commit e2b55fb.
1 parent e2b55fb commit bd66642

File tree

6 files changed

+62
-76
lines changed

6 files changed

+62
-76
lines changed

notebooks/LaTeX_OCR_test.ipynb

+34-32
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"name": "LaTeX OCR test.ipynb",
7+
"provenance": [],
8+
"collapsed_sections": []
9+
},
10+
"kernelspec": {
11+
"name": "python3",
12+
"display_name": "Python 3"
13+
},
14+
"language_info": {
15+
"name": "python"
16+
}
17+
},
218
"cells": [
319
{
420
"cell_type": "markdown",
5-
"metadata": {
6-
"id": "aaAqi3wku23I"
7-
},
821
"source": [
922
"# LaTeX OCR\n",
1023
"In this colab you can convert an image of an equation into LaTeX code.\n",
@@ -14,7 +27,10 @@
1427
"Next, execute the cell below and upload the image(s).\n",
1528
"\n",
1629
"> Note: You can probably also run this project locally and with a GUI. Follow the steps on [GitHub](https://github.com/lukas-blecher/LaTeX-OCR)"
17-
]
30+
],
31+
"metadata": {
32+
"id": "aaAqi3wku23I"
33+
}
1834
},
1935
{
2036
"cell_type": "code",
@@ -55,46 +71,32 @@
5571
},
5672
{
5773
"cell_type": "code",
58-
"execution_count": null,
59-
"metadata": {
60-
"id": "CjrR3O07u3uH"
61-
},
62-
"outputs": [],
6374
"source": [
6475
"imgs = upload_files()\n",
6576
"predictions = []\n",
6677
"for name, f in imgs:\n",
6778
" img = Image.open(f)\n",
68-
" math = model.generate(img)\n",
79+
" math = model(img)\n",
6980
" print(math)\n",
7081
" predictions.append('\\\\mathrm{%s} & \\\\displaystyle{%s}'%(name, math))\n",
7182
"Math(table%'\\\\\\\\'.join(predictions))"
72-
]
83+
],
84+
"metadata": {
85+
"id": "CjrR3O07u3uH"
86+
},
87+
"execution_count": null,
88+
"outputs": []
7389
},
7490
{
7591
"cell_type": "code",
76-
"execution_count": null,
92+
"source": [
93+
""
94+
],
7795
"metadata": {
7896
"id": "ZqCH-4XoCkMO"
7997
},
80-
"outputs": [],
81-
"source": []
82-
}
83-
],
84-
"metadata": {
85-
"colab": {
86-
"collapsed_sections": [],
87-
"name": "LaTeX OCR test.ipynb",
88-
"provenance": []
89-
},
90-
"kernelspec": {
91-
"display_name": "Python 3",
92-
"name": "python3"
93-
},
94-
"language_info": {
95-
"name": "python"
98+
"execution_count": null,
99+
"outputs": []
96100
}
97-
},
98-
"nbformat": 4,
99-
"nbformat_minor": 0
100-
}
101+
]
102+
}

pix2tex/api/app.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async def predict(file: UploadFile = File(...)) -> str:
4545
"""
4646
global model
4747
image = Image.open(file.file)
48-
return model.generate(image)
48+
return model(image)
4949

5050

5151
@app.post('/bytes/')
@@ -61,4 +61,4 @@ async def predict_from_bytes(file: bytes = File(...)) -> str: # , size: str = F
6161
global model
6262
#size = tuple(int(a) for a in size.split(','))
6363
image = Image.open(BytesIO(file))
64-
return model.generate(image, resize=False)
64+
return model(image, resize=False)

pix2tex/cli.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(self, arguments=None):
7575
download_checkpoints()
7676
self.model = get_model(self.args)
7777
self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
78+
self.model.eval()
7879

7980
if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
8081
self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
@@ -123,13 +124,8 @@ def __call__(self, img=None, resize=True) -> str:
123124
t = test_transform(image=img)['image'][:1].unsqueeze(0)
124125
im = t.to(self.args.device)
125126

126-
with torch.no_grad():
127-
self.model.eval()
128-
device = self.args.device
129-
encoded = self.model.encoder(im.to(device))
130-
dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
131-
eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
132-
pred = post_process(token2str(dec, self.tokenizer)[0])
127+
dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
128+
pred = post_process(token2str(dec, self.tokenizer)[0])
133129
try:
134130
clipboard.copy(pred)
135131
except:
@@ -220,7 +216,7 @@ def main():
220216
img = ImageGrab.grabclipboard()
221217
except:
222218
pass
223-
pred = model.generate(img)
219+
pred = model(img)
224220
output_prediction(pred, arguments)
225221
except KeyboardInterrupt:
226222
pass

pix2tex/eval.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
5050
for i, (seq, im) in pbar:
5151
if seq is None or im is None:
5252
continue
53-
encoded = model.encoder(im.to(device))
5453
#loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
55-
dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
56-
eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
54+
dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
5755
pred = detokenize(dec, dataset.tokenizer)
5856
truth = detokenize(seq['input_ids'], dataset.tokenizer)
5957
bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))

pix2tex/models/utils.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,28 @@ def __init__(self, encoder, decoder, args):
1313
self.decoder = decoder
1414
self.args = args
1515

16-
def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
16+
def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
17+
if not device_ids or len(device_ids) == 1:
18+
return self(x, **kwargs)
19+
if output_device is None:
20+
output_device = device_ids[0]
21+
replicas = nn.parallel.replicate(self, device_ids)
22+
inputs = nn.parallel.scatter(x, device_ids) # Slices tensors into approximately equal chunks and distributes them across given GPUs.
23+
kwargs = nn.parallel.scatter(kwargs, device_ids) # Duplicates references to objects that are not tensors.
24+
replicas = replicas[:len(inputs)]
25+
kwargs = kwargs[:len(inputs)]
26+
outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
27+
return nn.parallel.gather(outputs, output_device).mean()
28+
29+
def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
1730
encoded = self.encoder(x)
1831
out = self.decoder(tgt_seq, context=encoded, **kwargs)
1932
return out
2033

2134
@torch.no_grad()
22-
def generate(self, x: torch.Tensor):
23-
return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device),
24-
self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
35+
def generate(self, x: torch.Tensor, temperature: float = 0.25):
36+
return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len,
37+
eos_token=self.args.eos_token, context=self.encoder(x), temperature=temperature)
2538

2639

2740
def get_model(args):

pix2tex/train.py

+4-27
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,14 @@
1515
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
1616

1717

18-
def data_parallel(module, x:torch.Tensor, device_ids, output_device=None, **kwargs):
19-
if not device_ids or len(device_ids) == 1:
20-
return module(x, **kwargs)
21-
if output_device is None:
22-
output_device = device_ids[0]
23-
replicas = nn.parallel.replicate(module, device_ids)
24-
inputs = nn.parallel.scatter(x, device_ids) # Slices tensors into approximately equal chunks and distributes them across given GPUs.
25-
kwargs = nn.parallel.scatter(kwargs, device_ids) # Duplicates references to objects that are not tensors.
26-
replicas = replicas[:len(inputs)]
27-
kwargs = kwargs[:len(inputs)]
28-
outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
29-
return nn.parallel.gather(outputs, output_device)
30-
31-
3218
def gpu_memory_check(model, args):
3319
# check if largest batch can be handled by system
3420
try:
3521
batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
3622
for _ in range(5):
3723
im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
3824
seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
39-
# model.decoder(seq, context=model.encoder(im)).sum().backward()
40-
# encoded = data_parallel(model.encoder, inputs=im, device_ids=args.gpu_devices)
41-
# loss = data_parallel(model.decoder, inputs=seq, device_ids=args.gpu_devices, context=encoded)
42-
loss = data_parallel(model, im, device_ids=args.gpu_devices, tgt_seq=seq)
25+
loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
4326
loss.sum().backward()
4427
except RuntimeError:
4528
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
@@ -60,7 +43,6 @@ def train(args):
6043
gpu_memory_check(model, args)
6144
if args.load_chkpt is not None:
6245
model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
63-
encoder, decoder = model.encoder, model.decoder
6446
max_bleu, max_token_acc = 0, 0
6547
out_path = os.path.join(args.model_path, args.name)
6648
os.makedirs(out_path, exist_ok=True)
@@ -86,14 +68,9 @@ def save_models(e, step=0):
8668
total_loss = 0
8769
for j in range(0, len(im), microbatch):
8870
tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
89-
# encoded = encoder(im[j:j+microbatch].to(device))
90-
# encoded = data_parallel(encoder, inputs=im[j:j+microbatch].to(device), device_ids=args.gpu_devices)
91-
# loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
92-
# loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
93-
# loss.backward()
94-
loss = data_parallel(model,im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
95-
loss.mean().backward() # data parallelism: loss is a vector (one entry per device)
96-
total_loss += loss.mean().item()
71+
loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
72+
loss.backward() # data parallelism: model.data_parallel already averages the gathered per-device losses
73+
total_loss += loss.item()
9774
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
9875
opt.step()
9976
scheduler.step()

0 commit comments

Comments
 (0)