Skip to content

Commit

Permalink
pretraining bug fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
logicwong committed Jul 29, 2022
1 parent b8fb736 commit 7aed680
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions data/pretrain_data/unify_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,11 +627,10 @@ def collater(self, samples, pad_to_length=None):
for sample_tuple in samples:
samples_v1 += sample_tuple[0]
samples_v2 += sample_tuple[1]
- if samples_v2 == []:
- samples_v2 += self.process_pure_text(0) if self.pure_text_dataset else []
- samples_v2 += self.process_pure_image(0) if self.pure_image_dataset else []
- samples_v2 += self.process_detection(0) if self.detection_dataset else []
-
- res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
- res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
- return res_v1, res_v2
+ if samples_v2 != []:
+ res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
+ res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
+ return res_v1, res_v2
+ else:
+ res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
+ return res_v1

0 comments on commit 7aed680

Please sign in to comment.