26
26
[alb .ShiftScaleRotate (shift_limit = 0 , scale_limit = (- .15 , 0 ), rotate_limit = 1 , border_mode = 0 , interpolation = 3 ,
27
27
value = [255 , 255 , 255 ], p = 1 ),
28
28
alb .GridDistortion (distort_limit = 0.1 , border_mode = 0 , interpolation = 3 , value = [255 , 255 , 255 ], p = .5 )], p = .15 ),
29
- alb .InvertImg (p = .15 ),
29
+ # alb.InvertImg(p=.15),
30
30
alb .RGBShift (r_shift_limit = 15 , g_shift_limit = 15 ,
31
31
b_shift_limit = 15 , p = 0.3 ),
32
32
alb .GaussNoise (10 , p = .2 ),
33
33
alb .RandomBrightnessContrast (.05 , (- .2 , 0 ), True , p = 0.2 ),
34
- alb .JpegCompression (95 , p = .5 ),
34
+ alb .JpegCompression (95 , p = .3 ),
35
35
alb .ToGray (always_apply = True ),
36
36
alb .Normalize ((0.7931 , 0.7931 , 0.7931 ), (0.1738 , 0.1738 , 0.1738 )),
37
37
# alb.Sharpen()
@@ -150,6 +150,13 @@ def prepare_data(self, batch):
150
150
"""
151
151
152
152
eqs , ims = batch .T
153
+ tok = self .tokenizer (list (eqs ), return_token_type_ids = False )
154
+ # pad with bos and eos token
155
+ for k , p in zip (tok , [[self .bos_token_id , self .eos_token_id ], [1 , 1 ]]):
156
+ tok [k ] = pad_sequence ([torch .LongTensor ([p [0 ]]+ x + [p [1 ]]) for x in tok [k ]], batch_first = True , padding_value = self .pad_token_id )
157
+ # check if sequence length is too long
158
+ if self .max_seq_len < tok ['attention_mask' ].shape [1 ]:
159
+ return next (self )
153
160
images = []
154
161
for path in list (ims ):
155
162
im = cv2 .imread (path )
@@ -162,13 +169,6 @@ def prepare_data(self, batch):
162
169
if np .random .random () < .04 :
163
170
im [im != 255 ] = 0
164
171
images .append (self .transform (image = im )['image' ][:1 ])
165
- tok = self .tokenizer (list (eqs ), return_token_type_ids = False )
166
- # pad with bos and eos token
167
- for k , p in zip (tok , [[self .bos_token_id , self .eos_token_id ], [1 , 1 ]]):
168
- tok [k ] = pad_sequence ([torch .LongTensor ([p [0 ]]+ x + [p [1 ]]) for x in tok [k ]], batch_first = True , padding_value = self .pad_token_id )
169
- # check if sequence length is too long
170
- if self .max_seq_len < len (tok [0 ]):
171
- return next (self )
172
172
try :
173
173
images = torch .cat (images ).float ().unsqueeze (1 )
174
174
except RuntimeError :
0 commit comments