@@ -53,6 +53,8 @@ class Im2LatexDataset:
53
53
shuffle = True
54
54
batchsize = 16
55
55
max_dimensions = (1024 , 512 )
56
+ min_dimensions = (32 , 32 )
57
+ max_seq_len = 1024
56
58
pad_token = "[PAD]"
57
59
bos_token = "[BOS]"
58
60
eos_token = "[EOS]"
@@ -61,7 +63,8 @@ class Im2LatexDataset:
61
63
eos_token_id = 2
62
64
transform = train_transform
63
65
64
- def __init__ (self , equations = None , images = None , tokenizer = None , shuffle = True , batchsize = 16 , max_dimensions = (1024 , 512 ), pad = False , keep_smaller_batches = False , test = False ):
66
+ def __init__ (self , equations = None , images = None , tokenizer = None , shuffle = True , batchsize = 16 , max_seq_len = 1024 ,
67
+ max_dimensions = (1024 , 512 ), min_dimensions = (32 , 32 ), pad = False , keep_smaller_batches = False , test = False ):
65
68
"""Generates a torch dataset from pairs of `equations` and `images`.
66
69
67
70
Args:
@@ -70,7 +73,9 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
70
73
tokenizer (str, optional): Path to saved tokenizer. Defaults to None.
71
74
shuffle (bool, optional): Defaults to True.
72
75
batchsize (int, optional): Defaults to 16.
76
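+ max_seq_len (int, optional): Maximal token sequence length; a batch whose sequence is longer is skipped and the next batch is returned instead. Defaults to 1024.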
73
77
max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
78
+ min_dimensions (tuple(int, int), optional): Minimal dimensions the model can handle. Defaults to (32, 32).
74
79
pad (bool): Pad the images to `max_dimensions`. Defaults to False.
75
80
keep_smaller_batches (bool): Whether to also return batches with smaller size than `batchsize`. Defaults to False.
76
81
test (bool): Whether to use the test transformation or not. Defaults to False.
@@ -86,6 +91,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
86
91
self .shuffle = shuffle
87
92
self .batchsize = batchsize
88
93
self .max_dimensions = max_dimensions
94
+ self .min_dimensions = min_dimensions
89
95
self .pad = pad
90
96
self .keep_smaller_batches = keep_smaller_batches
91
97
self .test = test
@@ -94,7 +100,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
94
100
try :
95
101
for i , im in tqdm (enumerate (self .images ), total = len (self .images )):
96
102
width , height = imagesize .get (im )
97
- if width <= max_dimensions [0 ] and height <= max_dimensions [1 ]:
103
+ if min_dimensions [ 0 ] <= width <= max_dimensions [0 ] and min_dimensions [ 1 ] <= height <= max_dimensions [1 ]:
98
104
self .data [(width , height )].append ((eqs [self .indices [i ]], im ))
99
105
except KeyboardInterrupt :
100
106
pass
@@ -160,6 +166,9 @@ def prepare_data(self, batch):
160
166
# pad with bos and eos token
161
167
for k , p in zip (tok , [[self .bos_token_id , self .eos_token_id ], [1 , 1 ]]):
162
168
tok [k ] = pad_sequence ([torch .LongTensor ([p [0 ]]+ x + [p [1 ]]) for x in tok [k ]], batch_first = True , padding_value = self .pad_token_id )
169
+ # check if sequence length is too long
170
+ if self .max_seq_len < len (tok [0 ]):
171
+ return next (self )
163
172
try :
164
173
images = torch .cat (images ).float ().unsqueeze (1 )
165
174
except RuntimeError :
@@ -196,14 +205,17 @@ def save(self, filename):
196
205
pickle .dump (self , file )
197
206
198
207
def update (self , ** kwargs ):
199
- for k in ['batchsize' , 'shuffle' , 'pad' , 'keep_smaller_batches' , 'test' ]:
208
+ for k in ['batchsize' , 'shuffle' , 'pad' , 'keep_smaller_batches' , 'test' , 'max_seq_len' ]:
200
209
if k in kwargs :
201
210
setattr (self , k , kwargs [k ])
202
- if 'max_dimensions' in kwargs :
203
- self .max_dimensions = kwargs ['max_dimensions' ]
211
+ if 'max_dimensions' in kwargs or 'min_dimensions' in kwargs :
212
+ if 'max_dimensions' in kwargs :
213
+ self .max_dimensions = kwargs ['max_dimensions' ]
214
+ if 'min_dimensions' in kwargs :
215
+ self .min_dimensions = kwargs ['min_dimensions' ]
204
216
temp = {}
205
217
for k in self .data :
206
- if 0 < k [0 ] <= self .max_dimensions [0 ] and 0 < k [1 ] <= self .max_dimensions [1 ]:
218
+ if self . min_dimensions [ 0 ] <= k [0 ] <= self .max_dimensions [0 ] and self . min_dimensions [ 1 ] <= k [1 ] <= self .max_dimensions [1 ]:
207
219
temp [k ] = self .data [k ]
208
220
self .data = temp
209
221
self ._get_size ()
0 commit comments