From e0cb47081cc6bde2c2613411efd7bff941e9f572 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 26 Feb 2022 13:56:22 -0800 Subject: [PATCH] bug fixes for head finetuning validation --- setup.py | 2 +- tf_bind_transformer/data_bigwig.py | 14 +++++++++++--- tf_bind_transformer/training_utils_bigwig.py | 9 ++++++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index b561c94..974eb43 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'tf-bind-transformer', packages = find_packages(exclude=[]), - version = '0.0.111', + version = '0.0.116', license='MIT', description = 'Transformer for Transcription Factor Binding', author = 'Phil Wang', diff --git a/tf_bind_transformer/data_bigwig.py b/tf_bind_transformer/data_bigwig.py index 3bd58ff..c3b19e6 100644 --- a/tf_bind_transformer/data_bigwig.py +++ b/tf_bind_transformer/data_bigwig.py @@ -113,7 +113,7 @@ def __init__( filtered_exp_ids = set(annot_df.get_column('column_1').to_list()) filtered_out_exp_ids = exp_ids - filtered_exp_ids - print(f'{', '.join(only_ref)} - {len(filtered_out_exp_ids)} experiments filtered out by lack of transcription factor fastas', filtered_out_exp_ids) + print(f'{", ".join(only_ref)} - {len(filtered_out_exp_ids)} experiments filtered out by lack of transcription factor fastas', filtered_out_exp_ids) # filter dataset by inclusion and exclusion list of targets # ( intersect ) subtract @@ -270,7 +270,7 @@ def __init__( # bigwigs - self.bigwigs = [pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw')) for i in self.annot.get_column("column_1")] + self.bigwigs = [(str(i), pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw'))) for i in self.annot.get_column("column_1")] self.downsample_factor = downsample_factor self.target_length = target_length @@ -294,7 +294,15 @@ def __getitem__(self, ind): # calculate bigwig # properly downsample and then crop - all_bw_values = [np.array(bw.values(chr_name, begin, end)) for bw in self.bigwigs] + all_bw_values = [] + + for bw_path, bw in self.bigwigs: + try: + bw_values = bw.values(chr_name, begin, end) + all_bw_values.append(bw_values) + except: + print(f'hitting invalid range for {bw_path} - ({chr_name}, {begin}, {end})') + exit() output = np.stack(all_bw_values, axis = -1) output = output.reshape((-1, self.downsample_factor, self.ntargets)) diff --git a/tf_bind_transformer/training_utils_bigwig.py b/tf_bind_transformer/training_utils_bigwig.py index 7a996f2..57b5868 100644 --- a/tf_bind_transformer/training_utils_bigwig.py +++ b/tf_bind_transformer/training_utils_bigwig.py @@ -310,7 +310,7 @@ def forward( if exists(self.train_mouse_head_dl): for _ in range(grad_accum_every): - seq, target = next(self.train_mouse_house_dl) + seq, target = next(self.train_mouse_head_dl) seq, target = seq.cuda(), target.cuda() loss = self.model( @@ -393,7 +393,7 @@ def forward( pred = self.model(seq, head = 'human') valid_loss = self.model.loss_fn(pred, target) - valid_corr_coef = pearson_corr_coef(pred, target) + valid_corr_coef = pearson_corr_coef(pred, target).mean() log = accum_log(log, { 'human_head_valid_loss': valid_loss.item() / grad_accum_every, @@ -404,11 +404,14 @@ def forward( if exists(self.valid_mouse_head_dl): for _ in range(grad_accum_every): - seq, target = next(self.valid_mouse_house_dl) + seq, target = next(self.valid_mouse_head_dl) seq, target = seq.cuda(), target.cuda() pred = self.model(seq, head = 'mouse') + valid_loss = self.model.loss_fn(pred, target) + valid_corr_coef = pearson_corr_coef(pred, target).mean() + log = accum_log(log, { 'mouse_head_valid_loss': valid_loss.item() / grad_accum_every, 'mouse_head_valid_corr_coef': valid_corr_coef.item() / grad_accum_every