add replace_N_token and pad_interval flags to hg38 pretraining
exnx committed Aug 11, 2023
1 parent 7016972 commit e152efa
Showing 3 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions configs/experiment/hg38/hg38_hyena.yaml
@@ -55,6 +55,8 @@ dataset:
rc_aug: false
num_workers: 12
use_fixed_len_val: false # placing a fixed length val here, but it's really the test
replace_N_token: false # replace N (uncertain token) with pad tokens in dataloader
pad_interval: false # pad out-of-range intervals with '.' inside the FastaInterval class

scheduler:
t_in_epochs: False
25 changes: 20 additions & 5 deletions src/dataloaders/datasets/hg38_dataset.py
@@ -46,7 +46,8 @@ def __init__(
# max_length = None,
return_seq_indices = False,
shift_augs = None,
rc_aug = False
rc_aug = False,
pad_interval = False,
):
fasta_file = Path(fasta_file)
assert fasta_file.exists(), 'path to fasta file must exist'
@@ -56,6 +57,7 @@ def __init__(
# self.max_length = max_length # -1 for adding sos or eos token
self.shift_augs = shift_augs
self.rc_aug = rc_aug
self.pad_interval = pad_interval

# calc len of each chromosome in fasta file, store in dict
self.chr_lens = {}
@@ -87,6 +89,8 @@ def __call__(self, chr_name, start, end, max_length, return_augs = False):
start += rand_shift
end += rand_shift

left_padding = right_padding = 0

# checks if not enough sequence to fill up the start to end
if interval_length < max_length:
extra_seq = max_length - interval_length
@@ -98,9 +102,11 @@ def __call__(self, chr_name, start, end, max_length, return_augs = False):
end += extra_right_seq

if start < 0:
left_padding = -start
start = 0

if end > chromosome_length:
right_padding = end - chromosome_length
end = chromosome_length

# Added support! need to allow shorter seqs
@@ -112,6 +118,9 @@ def __call__(self, chr_name, start, end, max_length, return_augs = False):
if self.rc_aug and coin_flip():
seq = string_reverse_complement(seq)

if self.pad_interval:
seq = ('.' * left_padding) + seq + ('.' * right_padding)

return seq

class HG38Dataset(torch.utils.data.Dataset):
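To make the new padding behavior concrete, here is a minimal standalone sketch (not from the repo; the helper name and toy sequence are made up) of what pad_interval does when a requested window overruns a chromosome boundary: the window is clamped to the chromosome, and the missing ends are filled with '.' so the returned string keeps its full length.

def fetch_padded(chromosome: str, start: int, end: int) -> str:
    # hypothetical helper mirroring the left_padding / right_padding logic above
    left_padding = max(0, -start)                    # window starts before position 0
    right_padding = max(0, end - len(chromosome))    # window runs past the chromosome end
    start, end = max(start, 0), min(end, len(chromosome))
    seq = chromosome[start:end]
    return ('.' * left_padding) + seq + ('.' * right_padding)

# e.g. fetch_padded("ACGTACGT", -2, 10) returns "..ACGTACGT.." (still end - start = 12 characters)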
@@ -134,7 +143,9 @@ def __init__(
return_seq_indices=False,
shift_augs=None,
rc_aug=False,
return_augs=False
return_augs=False,
replace_N_token=False, # replace N token with pad token
pad_interval = False, # pad out-of-range intervals with '.' inside FastaInterval
):

self.max_length = max_length
@@ -143,6 +154,8 @@ def __init__(
self.tokenizer = tokenizer
self.return_augs = return_augs
self.add_eos = add_eos
self.replace_N_token = replace_N_token
self.pad_interval = pad_interval

bed_path = Path(bed_file)
assert bed_path.exists(), 'path to .bed file must exist'
@@ -157,7 +170,8 @@ def __init__(
# max_length = max_length,
return_seq_indices = return_seq_indices,
shift_augs = shift_augs,
rc_aug = rc_aug
rc_aug = rc_aug,
pad_interval = pad_interval,
)

def __len__(self):
@@ -206,8 +220,9 @@ def __getitem__(self, idx):
# convert to tensor
seq = torch.LongTensor(seq) # hack, remove the initial cls tokens for now

# replace N token with a pad token, so we can ignore it in the loss
seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int['N'], self.tokenizer.pad_token_id)
if self.replace_N_token:
# replace N token with a pad token, so we can ignore it in the loss
seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int['N'], self.tokenizer.pad_token_id)

data = seq[:-1].clone() # remove eos
target = seq[1:].clone() # offset by 1, includes eos
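The replace_value helper itself is not shown in this diff; a plausible implementation, assuming seq is already a LongTensor of token ids, is a simple masked swap so that a loss configured to ignore the pad id also skips uncertain bases:

import torch

def replace_value(x: torch.Tensor, old: int, new: int) -> torch.Tensor:
    # swap every occurrence of `old` (e.g. the N token id) for `new` (the pad token id)
    return torch.where(x == old, torch.full_like(x, new), x)

# usage mirroring the guarded call above (token ids are illustrative):
# seq = replace_value(seq, n_token_id, pad_token_id)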
8 changes: 6 additions & 2 deletions src/dataloaders/genomics.py
@@ -48,7 +48,7 @@ def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_nam
max_length_val=None, max_length_test=None, val_ratio=0.0005, val_split_seed=2357, use_fixed_len_val=False,
add_eos=True, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,
shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,
fast_forward_epochs=None, fast_forward_batches=None,
fast_forward_epochs=None, fast_forward_batches=None, replace_N_token=False, pad_interval=False,
*args, **kwargs):
self.dataset_config_name = dataset_config_name
self.tokenizer_name = tokenizer_name
@@ -71,6 +71,8 @@ def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_nam
self.bed_file = bed_file
self.fasta_file = fasta_file
self.use_fixed_len_val = use_fixed_len_val
self.replace_N_token = replace_N_token
self.pad_interval = pad_interval

# handle if file paths are None (default paths)
if self.bed_file is None:
@@ -133,7 +135,9 @@ def init_datasets(self):
return_seq_indices=False,
shift_augs=None,
rc_aug=self.rc_aug,
return_augs=False)
return_augs=False,
replace_N_token=self.replace_N_token,
pad_interval=self.pad_interval)
for split, max_len in zip(['train', 'valid', 'test'], [self.max_length, self.max_length_val, self.max_length_test])
]
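Putting the pieces together, the two new YAML keys are simply accepted by the datamodule __init__ above and forwarded into each split's HG38Dataset. A hedged sketch of how the dataset section of hg38_hyena.yaml maps onto these keyword arguments (the datamodule class name, file paths, and config loading are not part of this diff):

# Mirrors the dataset: section of hg38_hyena.yaml; only the last two keys are new.
dataset_cfg = dict(
    rc_aug=False,
    num_workers=12,
    use_fixed_len_val=False,
    replace_N_token=False,   # if True, map the N token id to the pad id after tokenization
    pad_interval=False,      # if True, '.'-pad intervals that overrun a chromosome in FastaInterval
)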

