Skip to content

Commit

Permalink
Merge pull request lukas-blecher#150 from TITC/data-parallelism
Browse files Browse the repository at this point in the history
Data parallelism (multi-GPU training) + pure-ViT encoder support + small fixes
  • Loading branch information
lukas-blecher authored May 20, 2022
2 parents cac7f3a + 67d46d8 commit 06b7a9a
Show file tree
Hide file tree
Showing 14 changed files with 358 additions and 183 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,5 @@ pix2tex/model/checkpoints/**
!**/.gitkeep
.vscode
.DS_Store
test/*

10 changes: 3 additions & 7 deletions pix2tex/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, arguments=None):
download_checkpoints()
self.model = get_model(self.args)
self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
self.model.eval()

if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
Expand Down Expand Up @@ -123,13 +124,8 @@ def __call__(self, img=None, resize=True) -> str:
t = test_transform(image=img)['image'][:1].unsqueeze(0)
im = t.to(self.args.device)

with torch.no_grad():
self.model.eval()
device = self.args.device
encoded = self.model.encoder(im.to(device))
dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
pred = post_process(token2str(dec, self.tokenizer)[0])
dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
pred = post_process(token2str(dec, self.tokenizer)[0])
try:
clipboard.copy(pred)
except:
Expand Down
4 changes: 1 addition & 3 deletions pix2tex/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
for i, (seq, im) in pbar:
if seq is None or im is None:
continue
encoded = model.encoder(im.to(device))
#loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
pred = detokenize(dec, dataset.tokenizer)
truth = detokenize(seq['input_ids'], dataset.tokenizer)
bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
Expand Down
52 changes: 52 additions & 0 deletions pix2tex/model/settings/config-vit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
betas:
- 0.9
- 0.999
batchsize: 64
bos_token: 1
channels: 1
data: dataset/data/train.pkl
debug: false
decoder_args:
attn_on_attn: true
cross_attend: true
ff_glu: true
rel_pos_bias: false
use_scalenorm: false
dim: 256
emb_dropout: 0
encoder_depth: 4
eos_token: 2
epochs: 10
gamma: 0.9995
heads: 8
id: null
load_chkpt: null
lr: 0.0005
lr_step: 30
max_height: 192
max_seq_len: 512
max_width: 672
min_height: 32
min_width: 32
micro_batchsize: 64
model_path: checkpoints_add
name: pix2tex-vit
num_layers: 4
num_tokens: 8000
optimizer: Adam
output_path: outputs
pad: false
pad_token: 0
patch_size: 16
sample_freq: 1000
save_freq: 5
scheduler: StepLR
seed: 42
encoder_structure: vit
temperature: 0.2
test_samples: 5
testbatchsize: 20
tokenizer: dataset/tokenizer.json
valbatches: 100
valdata: dataset/data/val.pkl
2 changes: 2 additions & 0 deletions pix2tex/model/settings/config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
backbone_layers:
- 2
- 3
Expand Down Expand Up @@ -45,6 +46,7 @@ sample_freq: 3000
save_freq: 5
scheduler: StepLR
seed: 42
encoder_structure: hybrid
temperature: 0.2
test_samples: 5
testbatchsize: 20
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ decoder_args:
heads: 8
num_tokens: 8000
max_seq_len: 1024
encoder_structure: hybrid

# Other
seed: 42
Expand Down
160 changes: 0 additions & 160 deletions pix2tex/models.py

This file was deleted.

1 change: 1 addition & 0 deletions pix2tex/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .utils import *
56 changes: 56 additions & 0 deletions pix2tex/models/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
from einops import repeat

class CustomVisionTransformer(VisionTransformer):
    """ViT that accepts variable-sized inputs up to a fixed maximum canvas.

    `img_size` is an ``(height, width)`` pair describing the *maximum* input
    size. Positional embeddings are learned for the full max-size patch grid;
    for a smaller input, `forward_features` gathers only the embeddings whose
    2-D grid positions the input actually covers, so every patch keeps the
    embedding of its true row/column position.
    """

    def __init__(self, img_size=(224, 224), patch_size=16, *args, **kwargs):
        # Fix: the upstream default of a scalar 224 would crash on the tuple
        # unpack below; (224, 224) is equivalent for timm and unpacks cleanly.
        super(CustomVisionTransformer, self).__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
        self.height, self.width = img_size
        self.patch_size = patch_size

    def forward_features(self, x):
        """Encode image batch `x` (B, C, H, W) into patch-token features.

        H and W may be smaller than the configured maximum; both must be
        divisible by `patch_size` (integer division below assumes this).
        """
        B, c, h, w = x.shape
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        # Grid size of this particular input, in patches.
        h, w = h//self.patch_size, w//self.patch_size
        # Flat indices into the max-size positional grid: row i of the input
        # grid starts at offset i*(max_cols) in the stored embedding, so each
        # row skips (max_cols - w) unused columns of the full grid.
        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
        # Index 0 is the CLS-token embedding; patch indices shift up by one.
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
        x += self.pos_embed[:, pos_emb_ind]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x


def get_encoder(args):
    """Build the hybrid (ResNetV2 stem + ViT) encoder described by `args`.

    The ResNet backbone acts as the patch-embedding front end of a
    `CustomVisionTransformer` sized for the maximum image dimensions.
    """
    # Convolutional feature extractor feeding the transformer's patch embed.
    resnet = ResNetV2(layers=args.backbone_layers,
                      num_classes=0,
                      global_pool='',
                      in_chans=args.channels,
                      preact=False,
                      stem_type='same',
                      conv_layer=StdConv2dSame)
    # The backbone downsamples by 2**(num_stages + 1), which is the smallest
    # effective patch size the hybrid embedding can realize.
    min_patch_size = 2**(len(args.backbone_layers)+1)

    def build_embedding(**kwargs):
        # Translate the requested patch size into the backbone's output scale.
        requested = kwargs.pop('patch_size', min_patch_size)
        assert requested % min_patch_size == 0 and requested >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
        return HybridEmbed(patch_size=requested//min_patch_size, backbone=resnet, **kwargs)

    return CustomVisionTransformer(
        img_size=(args.max_height, args.max_width),
        patch_size=args.patch_size,
        in_chans=args.channels,
        num_classes=0,
        embed_dim=args.dim,
        depth=args.encoder_depth,
        num_heads=args.heads,
        embed_layer=build_embedding,
    )
Loading

0 comments on commit 06b7a9a

Please sign in to comment.