
Commit 5fca76e

restructure model code
1 parent 27b620f

14 files changed: +162 -183 lines

docs/index.rst (+1, -1)

@@ -14,7 +14,7 @@ Welcome to LaTeX-OCR's documentation!
    pix2tex.gui
    pix2tex.api
    pix2tex.dataset
-   pix2tex.structures.hybrid
+   pix2tex.models
    pix2tex.utils

docs/pix2tex.models.rst (+11, -3)

@@ -1,7 +1,15 @@
-pix2tex.structures.hybrid package
-======================
+pix2tex.models.hybrid package
+=============================
 
-.. automodule:: pix2tex.structures.hybrid
+.. automodule:: pix2tex.models.hybrid
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+pix2tex.models.vit package
+==========================
+
+.. automodule:: pix2tex.models.vit
    :members:
    :no-undoc-members:
    :show-inheritance:

pix2tex/cli.py (+1, -1)

@@ -17,7 +17,7 @@
 from timm.models.layers import StdConv2dSame
 
 from pix2tex.dataset.latex2png import tex2pil
-from pix2tex.structures.hybrid import get_model
+from pix2tex.models import get_model
 from pix2tex.utils import *
 from pix2tex.model.checkpoints.get_latest_checkpoint import download_checkpoints

pix2tex/eval.py (+1, -1)

@@ -11,7 +11,7 @@
 import wandb
 from Levenshtein import distance
 
-from pix2tex.structures.hybrid import get_model, Model
+from pix2tex.models import get_model, Model
 from pix2tex.utils import *

pix2tex/model/settings/config-vit.yaml (+2)

@@ -13,6 +13,7 @@ decoder_args:
   rel_pos_bias: false
   use_scalenorm: false
 dim: 256
+emb_dropout: 0
 encoder_depth: 4
 eos_token: 2
 epochs: 10
@@ -41,6 +42,7 @@ sample_freq: 1000
 save_freq: 5
 scheduler: StepLR
 seed: 42
+encoder_structure: vit
 temperature: 0.2
 test_samples: 5
 testbatchsize: 20

pix2tex/model/settings/config.yaml (+1)

@@ -45,6 +45,7 @@ sample_freq: 3000
 save_freq: 5
 scheduler: StepLR
 seed: 42
+encoder_structure: hybrid
 temperature: 0.2
 test_samples: 5
 testbatchsize: 20

pix2tex/model/settings/debug.yaml (+1)

@@ -51,6 +51,7 @@ decoder_args:
 heads: 8
 num_tokens: 8000
 max_seq_len: 1024
+encoder_structure: hybrid
 
 # Other
 seed: 42

pix2tex/models/__init__.py (new file, +1)

@@ -0,0 +1 @@
+from .utils import *
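
The new package root only re-exports the helpers from pix2tex.models.utils, which is what lets cli.py and eval.py above switch to the shorter import path. A minimal sketch of the new import surface (nothing here beyond what the diff itself introduces):

# Old path, removed in this commit:
#   from pix2tex.structures.hybrid import get_model, Model
# New path:
from pix2tex.models import get_model, Model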

pix2tex/structures/hybrid.py -> pix2tex/models/hybrid.py (renamed, +2, -39)

@@ -1,10 +1,7 @@
-# taken and modified from https://github.com/lukas-blecher/LaTeX-OCR/blob/720978d8c469780ed070d041d5795c55b705ac1b/pix2tex/models.py
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-# from x_transformers import *
-from x_transformers import TransformerWrapper, Decoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p, entmax, ENTMAX_ALPHA
 from timm.models.vision_transformer import VisionTransformer
 from timm.models.vision_transformer_hybrid import HybridEmbed
@@ -108,7 +105,7 @@ def forward(self, x: torch.Tensor):
         return dec
 
 
-def get_model(args, training=False):
+def get_encoder(args, training=False):
     backbone = ResNetV2(
         layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
         preact=False, stem_type='same', conv_layer=StdConv2dSame)
@@ -128,38 +125,4 @@ def embed_layer(**x):
                                       num_heads=args.heads,
                                       embed_layer=embed_layer
                                       )
-
-    decoder = CustomARWrapper(
-        TransformerWrapper(
-            num_tokens=args.num_tokens,
-            max_seq_len=args.max_seq_len,
-            attn_layers=Decoder(
-                dim=args.dim,
-                depth=args.num_layers,
-                heads=args.heads,
-                **args.decoder_args
-            )),
-        pad_value=args.pad_token
-    )
-    #to device
-    available_gpus = torch.cuda.device_count()
-    if available_gpus > 1:
-        encoder = nn.DataParallel(encoder)
-        decoder = nn.DataParallel(decoder)
-    encoder.to(args.device)
-    decoder.to(args.device)
-    if 'wandb' in args and args.wandb:
-        import wandb
-        de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
-        wandb.watch((encoder, de_attn_layers))
-    model = Model(encoder, decoder, args)
-    if training:
-        # check if largest batch can be handled by system
-        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
-        im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
-        seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-        decoder(seq, context=encoder(im)).sum().backward()
-        model.zero_grad()
-        torch.cuda.empty_cache()
-        del im, seq
-    return model
+    return encoder
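
After the move, hybrid.py only builds the ResNetV2-backed vision-transformer encoder; the decoder, device placement, and wandb hooks now live in pix2tex/models/utils.py (below). A minimal sketch of calling the new get_encoder directly; the toy config values are hypothetical, and Munch is only an assumption standing in for any dict with attribute access (real values come from the settings YAML files):

import torch
from munch import Munch  # assumption: any dict-like object with attribute access works here

from pix2tex.models import hybrid

# Hypothetical, reduced configuration; real values live in pix2tex/model/settings/config.yaml.
args = Munch(backbone_layers=[2, 3, 7], channels=1, patch_size=16,
             max_width=128, max_height=64, dim=256, encoder_depth=4, heads=8)

encoder = hybrid.get_encoder(args)
features = encoder(torch.zeros(1, 1, 64, 128))  # patch embeddings the decoder cross-attends to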

pix2tex/models/utils.py (new file, +65)

@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+
+from x_transformers import TransformerWrapper, Encoder, Decoder
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+from . import hybrid
+from . import vit
+
+
+class Model(nn.Module):
+    def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.args = args
+
+    def forward(self, x: torch.Tensor):
+        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
+
+
+def get_model(args, training=False):
+    if args.encoder_structure.lower() == 'vit':
+        encoder = vit.get_encoder(args)
+    elif args.encoder_structure.lower() == 'hybrid':
+        encoder = hybrid.get_encoder(args)
+    else:
+        raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
+    decoder = AutoregressiveWrapper(
+        TransformerWrapper(
+            num_tokens=args.num_tokens,
+            max_seq_len=args.max_seq_len,
+            attn_layers=Decoder(
+                dim=args.dim,
+                depth=args.num_layers,
+                heads=args.heads,
+                cross_attend=True
+            )),
+        pad_value=args.pad_token
+    )
+    available_gpus = torch.cuda.device_count()
+    if available_gpus > 1:
+        encoder = nn.DataParallel(encoder)
+        decoder = nn.DataParallel(decoder)
+    encoder.to(args.device)
+    decoder.to(args.device)
+    if args.wandb:
+        import wandb
+        en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
+        de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
+        wandb.watch((en_attn_layers, de_attn_layers))
+    model = Model(encoder, decoder, args)
+    if training:
+        # check if largest batch can be handled by system
+        batchsize = args.batchsize if args.get(
+            'micro_batchsize', -1) == -1 else args.micro_batchsize
+        im = torch.empty(batchsize, args.channels, args.max_height,
+                         args.min_height, device=args.device).float()
+        seq = torch.randint(0, args.num_tokens, (batchsize,
+                            args.max_seq_len), device=args.device).long()
+        decoder(seq, context=encoder(im)).sum().backward()
+        model.zero_grad()
+        torch.cuda.empty_cache()
+        del im, seq
+    return model
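
get_model now dispatches on the new encoder_structure key added to the settings files above and pairs whichever encoder it picks with an x_transformers autoregressive decoder (note the decoder here hard-codes cross_attend=True instead of expanding **args.decoder_args as the old hybrid get_model did). A minimal sketch of the dispatch with a hypothetical, reduced configuration; Munch is only an assumption for a dict with attribute access and .get():

from munch import Munch  # assumption: stands in for however the settings YAML is loaded
from pix2tex.models import get_model

# Hypothetical toy configuration; real runs read pix2tex/model/settings/*.yaml,
# which now carry the `encoder_structure` key ('vit' or 'hybrid').
args = Munch(encoder_structure='vit',
             max_width=128, max_height=64, channels=1, patch_size=16,
             dim=256, num_layers=4, heads=8,
             num_tokens=8000, max_seq_len=512,
             pad_token=0, bos_token=1, eos_token=2,
             wandb=False, device='cpu')

model = get_model(args)  # ViT encoder + cross-attending Transformer decoder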

pix2tex/models/vit.py (new file, +74)

@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+from x_transformers import Encoder
+from einops import rearrange, repeat
+
+
+class ViTransformerWrapper(nn.Module):
+    def __init__(
+        self,
+        *,
+        max_width,
+        max_height,
+        patch_size,
+        attn_layers,
+        channels=1,
+        num_classes=None,
+        dropout=0.,
+        emb_dropout=0.
+    ):
+        super().__init__()
+        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
+        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
+        dim = attn_layers.dim
+        num_patches = (max_width // patch_size)*(max_height // patch_size)
+        patch_dim = channels * patch_size ** 2
+
+        self.patch_size = patch_size
+        self.max_width = max_width
+        self.max_height = max_height
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.attn_layers = attn_layers
+        self.norm = nn.LayerNorm(dim)
+        #self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
+
+    def forward(self, img, **kwargs):
+        p = self.patch_size
+
+        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
+        x = self.patch_to_embedding(x)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        h, w = torch.tensor(img.shape[2:])//p
+        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
+        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
+        x += self.pos_embedding[:, pos_emb_ind]
+        x = self.dropout(x)
+
+        x = self.attn_layers(x, **kwargs)
+        x = self.norm(x)
+
+        return x
+
+
+def get_encoder(args):
+    return ViTransformerWrapper(
+        max_width=args.max_width,
+        max_height=args.max_height,
+        channels=args.channels,
+        patch_size=args.patch_size,
+        emb_dropout=args.get('emb_dropout', 0),
+        attn_layers=Encoder(
+            dim=args.dim,
+            depth=args.num_layers,
+            heads=args.heads,
+        )
+    )
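
The new ViT encoder flattens the image into patch embeddings and, since inputs can be smaller than max_width x max_height, selects positional embeddings by the indices the patches would occupy in the full-size grid. A minimal sketch of the resulting shapes with hypothetical toy dimensions (Munch again only stands in for a dict with attribute access and .get()):

import torch
from munch import Munch  # assumption: any dict-like object with attribute access works here

from pix2tex.models import vit

# Hypothetical toy configuration; real values live in config-vit.yaml.
args = Munch(max_width=128, max_height=64, channels=1, patch_size=16,
             dim=256, num_layers=4, heads=8)

encoder = vit.get_encoder(args)
img = torch.zeros(2, 1, 32, 64)   # smaller than the configured maximum is fine
out = encoder(img)                # cls token + (32//16)*(64//16) patch embeddings
print(out.shape)                  # torch.Size([2, 9, 256])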
