Commit 95a9fae

fix in dataloader: too long sequences are dropped (lukas-blecher#8)
1 parent 13434d6 commit 95a9fae

File tree

2 files changed: +10 -12 lines

dataset/dataset.py

+9 -9
@@ -26,12 +26,12 @@
         [alb.ShiftScaleRotate(shift_limit=0, scale_limit=(-.15, 0), rotate_limit=1, border_mode=0, interpolation=3,
                               value=[255, 255, 255], p=1),
          alb.GridDistortion(distort_limit=0.1, border_mode=0, interpolation=3, value=[255, 255, 255], p=.5)], p=.15),
-        alb.InvertImg(p=.15),
+        #alb.InvertImg(p=.15),
         alb.RGBShift(r_shift_limit=15, g_shift_limit=15,
                      b_shift_limit=15, p=0.3),
         alb.GaussNoise(10, p=.2),
         alb.RandomBrightnessContrast(.05, (-.2, 0), True, p=0.2),
-        alb.JpegCompression(95, p=.5),
+        alb.JpegCompression(95, p=.3),
         alb.ToGray(always_apply=True),
         alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
         # alb.Sharpen()
@@ -150,6 +150,13 @@ def prepare_data(self, batch):
         """
 
         eqs, ims = batch.T
+        tok = self.tokenizer(list(eqs), return_token_type_ids=False)
+        # pad with bos and eos token
+        for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
+            tok[k] = pad_sequence([torch.LongTensor([p[0]]+x+[p[1]]) for x in tok[k]], batch_first=True, padding_value=self.pad_token_id)
+        # check if sequence length is too long
+        if self.max_seq_len < tok['attention_mask'].shape[1]:
+            return next(self)
         images = []
         for path in list(ims):
             im = cv2.imread(path)
@@ -162,13 +169,6 @@
             if np.random.random() < .04:
                 im[im != 255] = 0
             images.append(self.transform(image=im)['image'][:1])
-        tok = self.tokenizer(list(eqs), return_token_type_ids=False)
-        # pad with bos and eos token
-        for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
-            tok[k] = pad_sequence([torch.LongTensor([p[0]]+x+[p[1]]) for x in tok[k]], batch_first=True, padding_value=self.pad_token_id)
-        # check if sequence length is too long
-        if self.max_seq_len < len(tok[0]):
-            return next(self)
         try:
            images = torch.cat(images).float().unsqueeze(1)
         except RuntimeError:
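
For context, the change in prepare_data moves tokenization ahead of image loading and fixes the over-length check: the tokenizer returns a dict-like object, so the old `len(tok[0])` never measured the padded batch length, while `tok['attention_mask'].shape[1]` does. Below is a minimal, self-contained sketch of that logic, assuming a Hugging Face style tokenizer output with 'input_ids' and 'attention_mask' keys and a hypothetical `dataset` object exposing `tokenizer`, `bos_token_id`, `eos_token_id`, `pad_token_id` and `max_seq_len`; the real class instead recurses with `return next(self)` to serve the following batch.

import torch
from torch.nn.utils.rnn import pad_sequence


def tokenize_or_skip(dataset, eqs):
    """Tokenize a batch of equations; return None if the padded batch is too long."""
    # Tokenize all equations at once (dict-like output: input_ids, attention_mask).
    tok = dataset.tokenizer(list(eqs), return_token_type_ids=False)
    # Wrap every sequence with BOS/EOS (and 1/1 for the attention mask),
    # then right-pad the whole batch to a common length.
    for k, p in zip(tok, [[dataset.bos_token_id, dataset.eos_token_id], [1, 1]]):
        tok[k] = pad_sequence(
            [torch.LongTensor([p[0]] + x + [p[1]]) for x in tok[k]],
            batch_first=True,
            padding_value=dataset.pad_token_id,
        )
    # The padded batch length is the width of the attention mask; dropping the
    # batch here avoids loading images for sequences that exceed max_seq_len.
    if dataset.max_seq_len < tok['attention_mask'].shape[1]:
        return None  # the commit instead recurses with `return next(self)`
    return tok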

train.py

+1 -3
@@ -51,7 +51,6 @@ def save_models(e):
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
             opt.step()
             scheduler.step()
-
             dset.set_description('Loss: %.4f' % loss.item())
             if args.wandb:
                 wandb.log({'train/loss': loss.item()})
@@ -79,8 +78,7 @@ def save_models(e):
     parsed_args = parser.parse_args()
     with parsed_args.config as f:
         params = yaml.load(f, Loader=yaml.FullLoader)
-    args = parse_args(Munch(params))
-
+    args = parse_args(Munch(params), **vars(parsed_args))
     logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
     seed_everything(args.seed)
     if args.wandb:
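
The second train.py hunk forwards the argparse namespace into parse_args so that command-line flags override values loaded from the YAML config. Here is a minimal sketch of such a merge, assuming Munch-style attribute access; the `parse_args` below is a simplified stand-in for the project's helper, not its actual implementation.

from munch import Munch


def parse_args(args: Munch, **kwargs) -> Munch:
    """Overlay keyword arguments (e.g. vars(parsed_args)) on top of the config."""
    merged = Munch(args)   # copy so the YAML-derived dict is not mutated
    merged.update(kwargs)  # CLI values win over config-file values
    return merged


# Usage sketch: a --debug flag passed on the command line overrides the config.
config = Munch({'batchsize': 16, 'debug': False})
cli_overrides = {'debug': True}
args = parse_args(config, **cli_overrides)
assert args.debug is True and args.batchsize == 16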
