restructure model code
lukas-blecher committed May 17, 2022
1 parent 27b620f commit 5fca76e
Showing 14 changed files with 162 additions and 183 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -14,7 +14,7 @@ Welcome to LaTeX-OCR's documentation!
    pix2tex.gui
    pix2tex.api
    pix2tex.dataset
-   pix2tex.structures.hybrid
+   pix2tex.models
    pix2tex.utils
 
 
14 changes: 11 additions & 3 deletions docs/pix2tex.models.rst
@@ -1,7 +1,15 @@
-pix2tex.structures.hybrid package
-======================
+pix2tex.models.hybrid package
+=============================
 
-.. automodule:: pix2tex.structures.hybrid
+.. automodule:: pix2tex.models.hybrid
    :members:
    :no-undoc-members:
    :show-inheritance:
+
+pix2tex.models.vit package
+==========================
+
+.. automodule:: pix2tex.models.vit
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
2 changes: 1 addition & 1 deletion pix2tex/cli.py
@@ -17,7 +17,7 @@
 from timm.models.layers import StdConv2dSame
 
 from pix2tex.dataset.latex2png import tex2pil
-from pix2tex.structures.hybrid import get_model
+from pix2tex.models import get_model
 from pix2tex.utils import *
 from pix2tex.model.checkpoints.get_latest_checkpoint import download_checkpoints
 
2 changes: 1 addition & 1 deletion pix2tex/eval.py
@@ -11,7 +11,7 @@
 import wandb
 from Levenshtein import distance
 
-from pix2tex.structures.hybrid import get_model, Model
+from pix2tex.models import get_model, Model
 from pix2tex.utils import *
 
 
2 changes: 2 additions & 0 deletions pix2tex/model/settings/config-vit.yaml
@@ -13,6 +13,7 @@ decoder_args:
   rel_pos_bias: false
   use_scalenorm: false
 dim: 256
+emb_dropout: 0
 encoder_depth: 4
 eos_token: 2
 epochs: 10
@@ -41,6 +42,7 @@ sample_freq: 1000
 save_freq: 5
 scheduler: StepLR
 seed: 42
+encoder_structure: vit
 temperature: 0.2
 test_samples: 5
 testbatchsize: 20
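Both additions are consumed by the new pix2tex.models package introduced below: get_model dispatches on encoder_structure, and vit.get_encoder reads emb_dropout with a default of 0, so existing configs without the key are unaffected. A condensed sketch of the dispatch (pick_encoder is an illustrative helper name; the real logic lives in pix2tex/models/utils.py):

from pix2tex.models import hybrid, vit

def pick_encoder(args):
    # Mirrors get_model's branch on the new encoder_structure key.
    structure = args.encoder_structure.lower()
    if structure == 'vit':
        return vit.get_encoder(args)     # reads args.get('emb_dropout', 0)
    if structure == 'hybrid':
        return hybrid.get_encoder(args)  # ResNetV2 backbone + ViT
    raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)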
1 change: 1 addition & 0 deletions pix2tex/model/settings/config.yaml
@@ -45,6 +45,7 @@ sample_freq: 3000
 save_freq: 5
 scheduler: StepLR
 seed: 42
+encoder_structure: hybrid
 temperature: 0.2
 test_samples: 5
 testbatchsize: 20
1 change: 1 addition & 0 deletions pix2tex/model/settings/debug.yaml
@@ -51,6 +51,7 @@ decoder_args:
 heads: 8
 num_tokens: 8000
 max_seq_len: 1024
+encoder_structure: hybrid
 
 # Other
 seed: 42
1 change: 1 addition & 0 deletions pix2tex/models/__init__.py
@@ -0,0 +1 @@
+from .utils import *
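The wildcard re-export makes pix2tex.models the package's public entry point, which is what the updated imports in cli.py and eval.py above rely on:

from pix2tex.models import get_model, Model  # re-exported from pix2tex/models/utils.py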
41 changes: 2 additions & 39 deletions pix2tex/structures/hybrid.py → pix2tex/models/hybrid.py
@@ -1,10 +1,7 @@
 # taken and modified from https://github.com/lukas-blecher/LaTeX-OCR/blob/720978d8c469780ed070d041d5795c55b705ac1b/pix2tex/models.py
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-# from x_transformers import *
-from x_transformers import TransformerWrapper, Decoder
-from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p, entmax, ENTMAX_ALPHA
 from timm.models.vision_transformer import VisionTransformer
 from timm.models.vision_transformer_hybrid import HybridEmbed
@@ -108,7 +105,7 @@ def forward(self, x: torch.Tensor):
         return dec
 
 
-def get_model(args, training=False):
+def get_encoder(args, training=False):
     backbone = ResNetV2(
         layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
         preact=False, stem_type='same', conv_layer=StdConv2dSame)
@@ -128,38 +125,4 @@ def embed_layer(**x):
         num_heads=args.heads,
         embed_layer=embed_layer
     )
-
-    decoder = CustomARWrapper(
-        TransformerWrapper(
-            num_tokens=args.num_tokens,
-            max_seq_len=args.max_seq_len,
-            attn_layers=Decoder(
-                dim=args.dim,
-                depth=args.num_layers,
-                heads=args.heads,
-                **args.decoder_args
-            )),
-        pad_value=args.pad_token
-    )
-    #to device
-    available_gpus = torch.cuda.device_count()
-    if available_gpus > 1:
-        encoder = nn.DataParallel(encoder)
-        decoder = nn.DataParallel(decoder)
-    encoder.to(args.device)
-    decoder.to(args.device)
-    if 'wandb' in args and args.wandb:
-        import wandb
-        de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
-        wandb.watch((encoder, de_attn_layers))
-    model = Model(encoder, decoder, args)
-    if training:
-        # check if largest batch can be handled by system
-        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
-        im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
-        seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-        decoder(seq, context=encoder(im)).sum().backward()
-        model.zero_grad()
-        torch.cuda.empty_cache()
-        del im, seq
-    return model
+    return encoder
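hybrid.get_encoder now returns only the ResNetV2-backed VisionTransformer; decoder construction, device placement, and the training dry run move to pix2tex/models/utils.py below. A minimal shape-check sketch under assumed hyperparameters (these values are illustrative, not taken from a shipped config):

import torch
from munch import Munch

from pix2tex.models import hybrid

args = Munch(backbone_layers=[2, 3, 7], channels=1, patch_size=16,
             max_height=192, max_width=672, dim=256, encoder_depth=4, heads=8)
encoder = hybrid.get_encoder(args)
x = torch.zeros(1, args.channels, 64, 256)  # (batch, channels, height, width)
print(encoder(x).shape)  # expected torch.Size([1, 65, 256]): 4*16 patch tokens + CLS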
65 changes: 65 additions & 0 deletions pix2tex/models/utils.py
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn

from x_transformers import TransformerWrapper, Encoder, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

from . import hybrid
from . import vit


class Model(nn.Module):
    def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.args = args

    def forward(self, x: torch.Tensor):
        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))


def get_model(args, training=False):
    if args.encoder_structure.lower() == 'vit':
        encoder = vit.get_encoder(args)
    elif args.encoder_structure.lower() == 'hybrid':
        encoder = hybrid.get_encoder(args)
    else:
        raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
    decoder = AutoregressiveWrapper(
        TransformerWrapper(
            num_tokens=args.num_tokens,
            max_seq_len=args.max_seq_len,
            attn_layers=Decoder(
                dim=args.dim,
                depth=args.num_layers,
                heads=args.heads,
                cross_attend=True
            )),
        pad_value=args.pad_token
    )
    available_gpus = torch.cuda.device_count()
    if available_gpus > 1:
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)
    encoder.to(args.device)
    decoder.to(args.device)
    if args.wandb:
        import wandb
        en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
        de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
        wandb.watch((en_attn_layers, de_attn_layers))
    model = Model(encoder, decoder, args)
    if training:
        # check if largest batch can be handled by system
        batchsize = args.batchsize if args.get(
            'micro_batchsize', -1) == -1 else args.micro_batchsize
        im = torch.empty(batchsize, args.channels, args.max_height,
                         args.min_height, device=args.device).float()
        seq = torch.randint(0, args.num_tokens, (batchsize,
                            args.max_seq_len), device=args.device).long()
        decoder(seq, context=encoder(im)).sum().backward()
        model.zero_grad()
        torch.cuda.empty_cache()
        del im, seq
    return model
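get_model is now the single factory used by cli.py and eval.py: it selects the encoder via encoder_structure, builds a cross-attending decoder (cross_attend=True replaces the old **args.decoder_args expansion), and with training=True dry-runs one maximal batch to surface out-of-memory errors early. A hedged usage sketch, assuming the YAML provides every key get_model reads (the device and wandb overrides are this sketch's assumptions):

import torch
import yaml
from munch import Munch

from pix2tex.models import get_model

# Load a settings file shipped in this commit; Munch gives attribute access.
with open('pix2tex/model/settings/config.yaml') as f:
    args = Munch.fromDict(yaml.safe_load(f))

args.device = 'cuda' if torch.cuda.is_available() else 'cpu'  # assumed not set in the YAML
args.wandb = False  # assumption: skip wandb.watch for this sketch

model = get_model(args)  # encoder_structure: hybrid -> hybrid.get_encoder(args)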
74 changes: 74 additions & 0 deletions pix2tex/models/vit.py
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn

from x_transformers import Encoder
from einops import rearrange, repeat


class ViTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        max_width,
        max_height,
        patch_size,
        attn_layers,
        channels=1,
        num_classes=None,
        dropout=0.,
        emb_dropout=0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (max_width // patch_size)*(max_height // patch_size)
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        self.max_width = max_width
        self.max_height = max_height

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        #self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None

    def forward(self, img, **kwargs):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        h, w = torch.tensor(img.shape[2:])//p
        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
        x += self.pos_embedding[:, pos_emb_ind]
        x = self.dropout(x)

        x = self.attn_layers(x, **kwargs)
        x = self.norm(x)

        return x


def get_encoder(args):
    return ViTransformerWrapper(
        max_width=args.max_width,
        max_height=args.max_height,
        channels=args.channels,
        patch_size=args.patch_size,
        emb_dropout=args.get('emb_dropout', 0),
        attn_layers=Encoder(
            dim=args.dim,
            depth=args.num_layers,
            heads=args.heads,
        )
    )
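The pos_emb_ind computation in forward lets one set of learned positional embeddings, laid out for the full max_width x max_height patch grid, serve smaller inputs: patch (i, j) maps to slot i*(max_width//p) + j + 1, with slot 0 reserved for the CLS token. A small worked example (a grid 4 patches wide is an illustrative value):

import torch
from einops import repeat

# Maximum grid is 4 patches wide (max_width // patch_size == 4);
# the input image is h=2 by w=2 patches.
max_grid_w, h, w = 4, 2, 2
pos_emb_ind = repeat(torch.arange(h) * (max_grid_w - w), 'h -> (h w)', w=w) + torch.arange(h * w)
print(pos_emb_ind)  # tensor([0, 1, 4, 5]): row-major positions in the full 4-wide grid
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
print(pos_emb_ind)  # tensor([0, 1, 2, 5, 6]): index 0 is reserved for the CLS token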