move data_parallel to Model, use generate in eval
This reverts some parts of commit e2b55fb.
lukas-blecher committed May 20, 2022
1 parent e2b55fb commit bd66642
Showing 6 changed files with 62 additions and 76 deletions.
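The net effect for callers: multi-GPU dispatch now lives on the Model class itself, and decoding becomes a single model.generate(...) call instead of hand-wired encoder/decoder plumbing. A minimal sketch of the two call shapes, using a stand-in module rather than the real pix2tex classes:

import torch
import torch.nn as nn


class DummyModel(nn.Module):
    """Stand-in for pix2tex.models.utils.Model after this commit."""

    def forward(self, x, tgt_seq=None, **kwargs):
        return x.mean()  # placeholder training loss

    def data_parallel(self, x, device_ids, output_device=None, **kwargs):
        # Single-device fallback, as in the diff below; the multi-GPU branch
        # (replicate/scatter/parallel_apply/gather) is shown in pix2tex/models/utils.py.
        if not device_ids or len(device_ids) == 1:
            return self(x, **kwargs)
        raise NotImplementedError

    @torch.no_grad()
    def generate(self, x, temperature=0.25):
        return torch.zeros(len(x), 1, dtype=torch.long)  # placeholder token ids


model = DummyModel()
im = torch.randn(4, 1, 64, 256)  # dummy image batch

loss = model.data_parallel(im, device_ids=[])  # was: data_parallel(model, im, ...)
dec = model.generate(im, temperature=0.25)     # was: model.decoder.generate(..., context=model.encoder(im))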
66 changes: 34 additions & 32 deletions notebooks/LaTeX_OCR_test.ipynb
@@ -1,10 +1,23 @@
 {
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "LaTeX OCR test.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
   "cells": [
     {
       "cell_type": "markdown",
-      "metadata": {
-        "id": "aaAqi3wku23I"
-      },
       "source": [
         "# LaTeX OCR\n",
         "In this colab you can convert an image of an equation into LaTeX code.\n",
@@ -14,7 +27,10 @@
         "Next, execute the cell below and upload the image(s).\n",
         "\n",
         "> Note: You can probably also run this project locally and with a GUI. Follow the steps on [GitHub](https://github.com/lukas-blecher/LaTeX-OCR)"
-      ]
+      ],
+      "metadata": {
+        "id": "aaAqi3wku23I"
+      }
     },
     {
       "cell_type": "code",
@@ -55,46 +71,32 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "CjrR3O07u3uH"
-      },
-      "outputs": [],
       "source": [
         "imgs = upload_files()\n",
         "predictions = []\n",
         "for name, f in imgs:\n",
         "    img = Image.open(f)\n",
-        "    math = model.generate(img)\n",
+        "    math = model(img)\n",
         "    print(math)\n",
         "    predictions.append('\\\\mathrm{%s} & \\\\displaystyle{%s}'%(name, math))\n",
         "Math(table%'\\\\\\\\'.join(predictions))"
-      ]
+      ],
+      "metadata": {
+        "id": "CjrR3O07u3uH"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "source": [
+        ""
+      ],
       "metadata": {
         "id": "ZqCH-4XoCkMO"
       },
-      "outputs": [],
-      "source": []
+      "execution_count": null,
+      "outputs": []
     }
-  ],
-  "metadata": {
-    "colab": {
-      "collapsed_sections": [],
-      "name": "LaTeX OCR test.ipynb",
-      "provenance": []
-    },
-    "kernelspec": {
-      "display_name": "Python 3",
-      "name": "python3"
-    },
-    "language_info": {
-      "name": "python"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 0
-}
+  ]
+}
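For context, the edited cell renders all predictions as one LaTeX array via IPython's Math display. upload_files and the table template are defined in an earlier cell that this diff does not show, so the definitions below are illustrative assumptions only:

from IPython.display import Math

table = r'\begin{array}{ll} %s \end{array}'  # hypothetical template; the real one lives in an unshown cell
predictions = [r'\mathrm{%s} & \displaystyle{%s}' % ('demo.png', 'E=mc^2')]
Math(table % '\\\\'.join(predictions))  # joins rows with the LaTeX row separator \\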
4 changes: 2 additions & 2 deletions pix2tex/api/app.py
@@ -45,7 +45,7 @@ async def predict(file: UploadFile = File(...)) -> str:
     """
     global model
     image = Image.open(file.file)
-    return model.generate(image)
+    return model(image)


 @app.post('/bytes/')
@@ -61,4 +61,4 @@ async def predict_from_bytes(file: bytes = File(...)) -> str:  # , size: str = F
     global model
     #size = tuple(int(a) for a in size.split(','))
     image = Image.open(BytesIO(file))
-    return model.generate(image, resize=False)
+    return model(image, resize=False)
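A client-side sketch for these endpoints. Only the /bytes/ route is visible in the hunks above; the multipart route's path and the host/port are assumptions, not taken from this diff:

import requests

with open('equation.png', 'rb') as f:
    # Multipart upload; the /predict/ path is a hypothetical name for the first route.
    print(requests.post('http://127.0.0.1:8502/predict/', files={'file': f}).text)

with open('equation.png', 'rb') as f:
    # Raw-bytes variant; the server skips resizing here (resize=False above).
    print(requests.post('http://127.0.0.1:8502/bytes/', files={'file': f.read()}).text)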
12 changes: 4 additions & 8 deletions pix2tex/cli.py
@@ -75,6 +75,7 @@ def __init__(self, arguments=None):
             download_checkpoints()
         self.model = get_model(self.args)
         self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
+        self.model.eval()

         if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
             self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
@@ -123,13 +124,8 @@ def __call__(self, img=None, resize=True) -> str:
         t = test_transform(image=img)['image'][:1].unsqueeze(0)
         im = t.to(self.args.device)

-        with torch.no_grad():
-            self.model.eval()
-            device = self.args.device
-            encoded = self.model.encoder(im.to(device))
-            dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
-                                              eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
-            pred = post_process(token2str(dec, self.tokenizer)[0])
+        dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
+        pred = post_process(token2str(dec, self.tokenizer)[0])
         try:
             clipboard.copy(pred)
         except:
@@ -220,7 +216,7 @@ def main():
                     img = ImageGrab.grabclipboard()
                 except:
                     pass
-            pred = model.generate(img)
+            pred = model(img)
             output_prediction(pred, arguments)
     except KeyboardInterrupt:
         pass
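After the revert, programmatic use goes back through the wrapper's __call__, which now delegates to self.model.generate internally. A minimal usage sketch, assuming the class edited above is importable as LatexOCR from pix2tex.cli:

from PIL import Image
from pix2tex.cli import LatexOCR

model = LatexOCR()                # loads the checkpoint and now calls .eval() once in __init__
img = Image.open('equation.png')  # any image of a rendered formula
print(model(img))                 # __call__ returns the predicted LaTeX string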
4 changes: 1 addition & 3 deletions pix2tex/eval.py
@@ -50,10 +50,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
     for i, (seq, im) in pbar:
         if seq is None or im is None:
             continue
-        encoded = model.encoder(im.to(device))
         #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
-        dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
-                                     eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
+        dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
         pred = detokenize(dec, dataset.tokenizer)
         truth = detokenize(seq['input_ids'], dataset.tokenizer)
         bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
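One subtlety of this hunk: the removed lines stopped decoding at args.pad_token, while Model.generate (see pix2tex/models/utils.py below) stops at args.eos_token. The metric call itself is unchanged; a self-contained illustration of how torchtext's bleu_score consumes the detokenized output:

from torchtext.data.metrics import bleu_score

# bleu_score takes tokenized candidates and, per candidate, a list of tokenized
# reference alternatives, mirroring `[alternatives(x) for x in truth]` above.
pred = [['\\frac', '{', 'a', '}', '{', 'b', '}']]
truth = [[['\\frac', '{', 'a', '}', '{', 'b', '}']]]
print(bleu_score(pred, truth))  # 1.0 for an exact match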
21 changes: 17 additions & 4 deletions pix2tex/models/utils.py
@@ -13,15 +13,28 @@ def __init__(self, encoder, decoder, args):
         self.decoder = decoder
         self.args = args

-    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
+    def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
+        if not device_ids or len(device_ids) == 1:
+            return self(x, **kwargs)
+        if output_device is None:
+            output_device = device_ids[0]
+        replicas = nn.parallel.replicate(self, device_ids)
+        inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+        kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
+        replicas = replicas[:len(inputs)]
+        kwargs = kwargs[:len(inputs)]
+        outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
+        return nn.parallel.gather(outputs, output_device).mean()
+
+    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
         encoded = self.encoder(x)
         out = self.decoder(tgt_seq, context=encoded, **kwargs)
         return out

     @torch.no_grad()
-    def generate(self, x: torch.Tensor):
-        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device),
-                                     self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
+    def generate(self, x: torch.Tensor, temperature: float = 0.25):
+        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len,
+                                     eos_token=self.args.eos_token, context=self.encoder(x), temperature=temperature)


 def get_model(args):
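The new method is essentially a hand-rolled nn.DataParallel forward pass with the loss reduction folded in: replicate the module, scatter the batch, run the replicas, gather, and average. Gathering scalar losses yields a vector (PyTorch unsqueezes scalars when gathering), hence the trailing .mean(). A toy sketch of the same primitives, not the pix2tex model itself:

import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(self, x, **kwargs):
        return x.pow(2).mean()  # scalar "loss" per replica


model, x = Toy(), torch.randn(8, 3)
device_ids = list(range(torch.cuda.device_count()))

if len(device_ids) <= 1:
    loss = model(x)  # the method's single-device fallback
else:
    replicas = nn.parallel.replicate(model.cuda(), device_ids)
    inputs = nn.parallel.scatter(x.cuda(), device_ids)
    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)
    # gather unsqueezes the scalar losses into a vector; .mean() reduces it
    loss = nn.parallel.gather(outputs, device_ids[0]).mean()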
31 changes: 4 additions & 27 deletions pix2tex/train.py
@@ -15,31 +15,14 @@
 from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler


-def data_parallel(module, x:torch.Tensor, device_ids, output_device=None, **kwargs):
-    if not device_ids or len(device_ids) == 1:
-        return module(x, **kwargs)
-    if output_device is None:
-        output_device = device_ids[0]
-    replicas = nn.parallel.replicate(module, device_ids)
-    inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
-    kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
-    replicas = replicas[:len(inputs)]
-    kwargs = kwargs[:len(inputs)]
-    outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
-    return nn.parallel.gather(outputs, output_device)
-
-
 def gpu_memory_check(model, args):
     # check if largest batch can be handled by system
     try:
         batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
         for _ in range(5):
             im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
             seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-            # model.decoder(seq, context=model.encoder(im)).sum().backward()
-            # encoded = data_parallel(model.encoder, inputs=im, device_ids=args.gpu_devices)
-            # loss = data_parallel(model.decoder, inputs=seq, device_ids=args.gpu_devices, context=encoded)
-            loss = data_parallel(model, im, device_ids=args.gpu_devices, tgt_seq=seq)
+            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
             loss.sum().backward()
     except RuntimeError:
         raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
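The probe above deliberately allocates the worst-case batch and runs forward and backward several times so an out-of-memory error surfaces before training starts. A self-contained sketch of the pattern with placeholder shapes, not pix2tex's real configuration:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))  # placeholder model
try:
    for _ in range(5):  # repeat to expose allocator fragmentation
        im = torch.empty(64, 1, 32, 32).float()  # largest batch you expect to see
        model(im).sum().backward()
except RuntimeError:
    raise RuntimeError('Batch does not fit in memory; reduce the micro batch size.')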
@@ -60,7 +43,6 @@ def train(args):
     gpu_memory_check(model, args)
     if args.load_chkpt is not None:
         model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
-    encoder, decoder = model.encoder, model.decoder
     max_bleu, max_token_acc = 0, 0
     out_path = os.path.join(args.model_path, args.name)
     os.makedirs(out_path, exist_ok=True)
@@ -86,14 +68,9 @@ def save_models(e, step=0):
             total_loss = 0
             for j in range(0, len(im), microbatch):
                 tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
-                # encoded = encoder(im[j:j+microbatch].to(device))
-                # encoded = data_parallel(encoder, inputs=im[j:j+microbatch].to(device), device_ids=args.gpu_devices)
-                # loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                # loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                # loss.backward()
-                loss = data_parallel(model, im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
-                loss.mean().backward()  # data parallelism loss is a vector
-                total_loss += loss.mean().item()
+                loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
+                loss.backward()  # data_parallel gathers and averages the per-GPU losses into a scalar
+                total_loss += loss.item()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
             opt.step()
             scheduler.step()
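The loop above is gradient accumulation: each microbatch loss is scaled by microbatch/args.batchsize before backward(), so the accumulated gradients match one full-batch step. A self-contained sketch of the scheme:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randn(16, 1)
batchsize, microbatch = 16, 4

opt.zero_grad()
for j in range(0, batchsize, microbatch):
    # mean loss over the slice, scaled so the slices sum to the full-batch mean
    loss = nn.functional.mse_loss(model(x[j:j+microbatch]), y[j:j+microbatch]) * microbatch / batchsize
    loss.backward()  # gradients accumulate in .grad across slices
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
opt.step()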