Skip to content

Commit 1b92175

Browse files
committed
add min_dimensions and respect max seq len
1 parent 005648b commit 1b92175

File tree

6 files changed

+28
-8
lines changed

6 files changed

+28
-8
lines changed

dataset/dataset.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class Im2LatexDataset:
5353
shuffle = True
5454
batchsize = 16
5555
max_dimensions = (1024, 512)
56+
min_dimensions = (32, 32)
57+
max_seq_len = 1024
5658
pad_token = "[PAD]"
5759
bos_token = "[BOS]"
5860
eos_token = "[EOS]"
@@ -61,7 +63,8 @@ class Im2LatexDataset:
6163
eos_token_id = 2
6264
transform = train_transform
6365

64-
def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_dimensions=(1024, 512), pad=False, keep_smaller_batches=False, test=False):
66+
def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_seq_len=1024,
67+
max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False, test=False):
6568
"""Generates a torch dataset from pairs of `equations` and `images`.
6669
6770
Args:
@@ -70,7 +73,9 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
7073
tokenizer (str, optional): Path to saved tokenizer. Defaults to None.
7174
shuffle (bool, optional): Defaults to True.
7275
batchsize (int, optional): Defaults to 16.
76+
max_seq_len (int, optional): Defaults to 1024.
7377
max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
78+
min_dimensions (tuple(int, int), optional): Minimal dimensions the model can handle
7479
pad (bool): Pad the images to `max_dimensions`. Defaults to False.
7580
keep_smaller_batches (bool): Whether to also return batches with smaller size than `batchsize`. Defaults to False.
7681
test (bool): Whether to use the test transformation or not. Defaults to False.
@@ -86,6 +91,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
8691
self.shuffle = shuffle
8792
self.batchsize = batchsize
8893
self.max_dimensions = max_dimensions
94+
self.min_dimensions = min_dimensions
8995
self.pad = pad
9096
self.keep_smaller_batches = keep_smaller_batches
9197
self.test = test
@@ -94,7 +100,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
94100
try:
95101
for i, im in tqdm(enumerate(self.images), total=len(self.images)):
96102
width, height = imagesize.get(im)
97-
if width <= max_dimensions[0] and height <= max_dimensions[1]:
103+
if min_dimensions[0] <= width <= max_dimensions[0] and min_dimensions[1] <= height <= max_dimensions[1]:
98104
self.data[(width, height)].append((eqs[self.indices[i]], im))
99105
except KeyboardInterrupt:
100106
pass
@@ -160,6 +166,9 @@ def prepare_data(self, batch):
160166
# pad with bos and eos token
161167
for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
162168
tok[k] = pad_sequence([torch.LongTensor([p[0]]+x+[p[1]]) for x in tok[k]], batch_first=True, padding_value=self.pad_token_id)
169+
# check if sequence length is too long
170+
if self.max_seq_len < len(tok[0]):
171+
return next(self)
163172
try:
164173
images = torch.cat(images).float().unsqueeze(1)
165174
except RuntimeError:
@@ -196,14 +205,17 @@ def save(self, filename):
196205
pickle.dump(self, file)
197206

198207
def update(self, **kwargs):
199-
for k in ['batchsize', 'shuffle', 'pad', 'keep_smaller_batches', 'test']:
208+
for k in ['batchsize', 'shuffle', 'pad', 'keep_smaller_batches', 'test', 'max_seq_len']:
200209
if k in kwargs:
201210
setattr(self, k, kwargs[k])
202-
if 'max_dimensions' in kwargs:
203-
self.max_dimensions = kwargs['max_dimensions']
211+
if 'max_dimensions' in kwargs or 'min_dimensions' in kwargs:
212+
if 'max_dimensions' in kwargs:
213+
self.max_dimensions = kwargs['max_dimensions']
214+
if 'min_dimensions' in kwargs:
215+
self.min_dimensions = kwargs['min_dimensions']
204216
temp = {}
205217
for k in self.data:
206-
if 0 < k[0] <= self.max_dimensions[0] and 0 < k[1] <= self.max_dimensions[1]:
218+
if self.min_dimensions[0] <= k[0] <= self.max_dimensions[0] and self.min_dimensions[1] <= k[1] <= self.max_dimensions[1]:
207219
temp[k] = self.data[k]
208220
self.data = temp
209221
self._get_size()

settings/config.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ backbone_layers:
1616
max_dimensions:
1717
- 672
1818
- 192
19+
min_dimensions:
20+
- 32
21+
- 32
1922
max_height: 192
2023
max_seq_len: 1024
2124
max_width: 672

settings/debug.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ betas: [0.9, 0.999]
2727
# Parameters for model architectures
2828
max_width: 128
2929
max_height: 128
30+
min_width: 32
31+
min_height: 32
3032
channels: 1
3133
patch_size: 32
3234
# Encoder / Decoder

settings/default.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ betas: [0.9, 0.999]
2626
# Parameters for model architectures
2727
max_width: 672
2828
max_height: 192
29+
min_width: 96
30+
min_height: 32
2931
channels: 1
3032
patch_size: 16
3133
# Encoder / Decoder
@@ -41,7 +43,7 @@ decoder_args:
4143
rel_pos_bias: false
4244
heads: 8
4345
num_tokens: 8000
44-
max_seq_len: 1024
46+
max_seq_len: 512
4547

4648
# Other
4749
seed: 42

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def save_models(e):
5656
if args.wandb:
5757
wandb.log({'train/loss': loss.item()})
5858
if (i+1) % args.sample_freq == 0:
59-
evaluate(model, valdataloader, args, num_batches=args.valbatches, name='val')
59+
evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
6060
if (e+1) % args.save_freq == 0:
6161
save_models(e)
6262
if args.wandb:

utils/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def parse_args(args, **kwargs):
5454
args.wandb = not kwargs.debug and not args.debug
5555
args.device = 'cuda' if torch.cuda.is_available() and not kwargs.no_cuda else 'cpu'
5656
args.max_dimensions = [args.max_width, args.max_height]
57+
args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
5758
if 'decoder_args' not in args or args.decoder_args is None:
5859
args.decoder_args = {}
5960
if 'model_path' in args:

0 commit comments

Comments
 (0)