Skip to content

Commit

Permalink
Merge pull request #53 from nanoporetech/small_fixes
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
zhengzhenxian authored Sep 29, 2021
2 parents 9051547 + fd01d26 commit be99492
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 30 deletions.
64 changes: 38 additions & 26 deletions clair3/CallVariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,11 +686,8 @@ def output_from(
hetero_InsDel_length_tuples, hetero_InsDel_probabilities,
) = all_pro
maximum_probability = 0.0
maximum_loops = 4
reference_base, alternate_base = None, None
loop_index = 0
while (reference_base is None or alternate_base is None) and loop_index < maximum_loops:
loop_index += 1
while (reference_base is None or alternate_base is None):
maximum_probability = max(
homo_Ref_probability,
max(homo_SNP_probabilities),
Expand Down Expand Up @@ -723,32 +720,37 @@ def output_from(

if is_homo_SNP:
reference_base = reference_sequence[tensor_position_center]
idx = homo_SNP_probabilities.index(maximum_probability)
base1, base2 = homo_SNP_bases_from(homo_SNP_probabilities)
alternate_base = base1 if base1 != reference_base else base2
sorted_alt_bases, alternate_base = find_alt_base(alt_info_dict, alternate_base)
if alternate_base is None or alternate_base == reference_base:
homo_SNP_probabilities[idx] = 0
continue

elif is_hetero_SNP:
base1, base2 = hetero_SNP_bases_from(hetero_SNP_probabilities)
idx = hetero_SNP_probabilities.index(maximum_probability)
reference_base = reference_sequence[tensor_position_center]
is_multi = base1 != reference_base and base2 != reference_base
if is_multi:
sorted_alt_bases, _ = find_alt_base(alt_info_dict)
if len(sorted_alt_bases) == 0:
break
if len(sorted_alt_bases) < 2:
alternate_base = sorted_alt_bases[0]
hetero_SNP_probabilities[np.argmax(hetero_SNP_probabilities)] = 0.0
break
hetero_SNP_probabilities[idx] = 0
continue
alternate_base = ','.join(sorted_alt_bases[:2])
else:
alternate_base = base1 if base1 != reference_base else base2
sorted_alt_bases, alternate_base = find_alt_base(alt_info_dict, alternate_base)
if alternate_base is None or alternate_base == reference_base:
hetero_SNP_probabilities[idx] = 0
continue


elif is_homo_insertion:
variant_length = None
idx = homo_Ins_probabilities.index(maximum_probability)
if add_indel_length:
idx = homo_Ins_probabilities.index(maximum_probability)
variant_length = homo_Ins_lengths[idx]
insertion_bases = insertion_bases_using_alt_info_from(
alt_info_dict=alt_info_dict,
Expand All @@ -757,7 +759,8 @@ def output_from(

insertion_length = len(insertion_bases)
if insertion_length == 0:
break
homo_Ins_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center]
alternate_base = insertion_bases

Expand All @@ -776,25 +779,27 @@ def output_from(
)
insertion_length = len(insertion_bases)
if insertion_length == 0:
break
hetero_ACGT_Ins_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center]
alternate_base = insertion_bases

is_SNP_Ins_multi = hetero_Ins_base != reference_base
if is_SNP_Ins_multi:
sorted_alt_bases, _ = find_alt_base(alt_info_dict)
if len(sorted_alt_bases) == 0:
break
hetero_ACGT_Ins_probabilities[idx] = 0
continue
else:
alternate_base = "{},{}".format(sorted_alt_bases[0], alternate_base)

elif is_hetero_InsIns:
insertion_bases_list = []
idx = hetero_InsIns_probabilities.index(maximum_probability)
if add_indel_length:
idx = hetero_InsIns_probabilities.index(maximum_probability)
variant_length_1, variant_length_2 = hetero_InsIns_length_tuples[idx]
del hetero_InsIns_probabilities[idx]
del hetero_InsIns_length_tuples[idx]
# del hetero_InsIns_probabilities[idx]
# del hetero_InsIns_length_tuples[idx]

insertion_bases1 = insertion_bases_using_alt_info_from(
alt_info_dict=alt_info_dict,
Expand All @@ -819,7 +824,8 @@ def output_from(
return_multi=True
)
if len(insertion_bases_list) < 2:
break
hetero_InsIns_probabilities[idx] = 0
continue
insertion_bases, another_insertion_bases = insertion_bases_list

reference_base = reference_sequence[tensor_position_center]
Expand All @@ -830,12 +836,13 @@ def output_from(
if alternate_base_1 != alternate_base_2:
alternate_base = "{},{}".format(alternate_base_1, alternate_base_2)
else:
reference_base, alternate_base = None, None
hetero_InsIns_probabilities[idx] = 0
continue

elif is_homo_deletion:
variant_length = None
idx = homo_Del_probabilities.index(maximum_probability)
if add_indel_length:
idx = homo_Del_probabilities.index(maximum_probability)
variant_length = homo_Del_lengths[idx]

deletion_bases = deletion_bases_using_alt_info_from(
Expand All @@ -844,7 +851,8 @@ def output_from(
)
deletion_length = len(deletion_bases)
if deletion_length == 0:
break
homo_Del_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center] + deletion_bases
alternate_base = reference_base[0]

Expand All @@ -862,7 +870,8 @@ def output_from(
)
deletion_length = len(deletion_bases)
if deletion_length == 0:
break
hetero_ACGT_Del_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center] + deletion_bases
alternate_base = reference_base[0]

Expand All @@ -874,8 +883,8 @@ def output_from(

elif is_hetero_DelDel:
deletion_bases_list = []
idx = hetero_DelDel_probabilities.index(maximum_probability)
if add_indel_length:
idx = hetero_DelDel_probabilities.index(maximum_probability)
variant_length_1, variant_length_2 = sorted(hetero_DelDel_length_tuples[idx],
reverse=True) # longer deletion should be in first position
deletion_base1 = deletion_bases_using_alt_info_from(
Expand Down Expand Up @@ -903,7 +912,8 @@ def output_from(
)

if len(deletion_bases_list) < 2:
break
hetero_DelDel_probabilities[idx] = 0
continue

deletion_bases, deletion_bases1 = deletion_bases_list

Expand All @@ -918,12 +928,13 @@ def output_from(
):
alternate_base = "{},{}".format(alternate_base_1, alternate_base_2)
else:
reference_base, alternate_base = None, None
hetero_DelDel_probabilities[idx] = 0
continue

elif is_insertion_and_deletion:
variant_length_1, variant_length_2 = None, None
idx = hetero_InsDel_probabilities.index(maximum_probability)
if add_indel_length:
idx = hetero_InsDel_probabilities.index(maximum_probability)
variant_length_1, variant_length_2 = hetero_InsDel_length_tuples[idx]

insertion_bases = insertion_bases_using_alt_info_from(
Expand All @@ -939,7 +950,8 @@ def output_from(
deletion_length = len(deletion_bases)

if insertion_length == 0 or deletion_length == 0:
break
hetero_InsDel_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center] + deletion_bases
alternate_base = "{},{}".format(
reference_base[0],
Expand Down
10 changes: 8 additions & 2 deletions clair3/Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def DataGenerator(x, data_size, shuffle_chunk_list, train_flag=True):
optimizer=optimizer
)
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode="min")
model_save_callbakck = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False)
model_save_callback = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False)
model_best_callback = tf.keras.callbacks.ModelCheckpoint("best_val_loss", monitor='val_loss', save_best_only=True, mode="min")
train_log_callback = tf.keras.callbacks.CSVLogger("training.log", separator='\t')

# Use first 20 element to initialize tensorflow model using graph mode
output = model(np.array(table_dataset_list[0].root.position_matrix[:20]))
Expand All @@ -228,11 +230,15 @@ def DataGenerator(x, data_size, shuffle_chunk_list, train_flag=True):
validate_dataset = validate_dataset if add_validation_dataset else None
if args.chkpnt_fn is not None:
model.load_weights(args.chkpnt_fn)
logging.info("[INFO] Starting from model {}".format(args.chkpnt_fn))

train_history = model.fit(x=train_dataset,
epochs=max_epoch,
validation_data=validate_dataset,
callbacks=[early_stop_callback, model_save_callbakck],
callbacks=[early_stop_callback,
model_save_callback,
model_best_callback,
train_log_callback],
verbose=1,
shuffle=False)

Expand Down
2 changes: 1 addition & 1 deletion preprocess/CreateTensorPileup.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def generate_tensor(pos, pileup_bases, reference_sequence, reference_start, refe

minimum_snp_af_for_candidate = minimum_snp_af_for_candidate if minimum_snp_af_for_candidate > 0 else param.min_af
minimum_snp_af_for_candidate = max(minimum_snp_af_for_candidate, param.min_af_dict[platform]) if fast_mode else minimum_snp_af_for_candidate
minimum_indel_af_for_candidate = max(minimum_indel_af_for_candidate, param.min_af_dict[platform]) if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform]
minimum_indel_af_for_candidate = minimum_indel_af_for_candidate if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform]

# check whether first non reference candidate in the first position
pass_af = len(pileup_list) and (pileup_list[0][0] != reference_base)
Expand Down
2 changes: 1 addition & 1 deletion run_clair3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ NC="\\033[0m"
ARGS=`getopt -o b:f:t:m:p:o:hv \
-l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\
bed_fn::,vcf_fn::,ctg_name::,sample_name::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,ref_pct_full::,\
snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_preix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,no_phasing_for_fa,call_snp_only,help,version -n 'run_clair3.sh' -- "$@"`
snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_prefix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,no_phasing_for_fa,call_snp_only,help,version -n 'run_clair3.sh' -- "$@"`

if [ $? != 0 ] ; then echo"No input. Terminating...">&2 ; exit 1 ; fi
eval set -- "${ARGS}"
Expand Down

0 comments on commit be99492

Please sign in to comment.