diff --git a/scripts/chng_between_sample.py b/scripts/chng_between_sample.py index 0816d7d6..4e3dd86d 100644 --- a/scripts/chng_between_sample.py +++ b/scripts/chng_between_sample.py @@ -140,8 +140,13 @@ class result: continue pair = natsorted([srr,read_file_to_pair_dict[srr]]) if 'restrict' in res.contig_name: - pair_to_res[tuple(pair)].append((res.final_ani)) - pair_to_res_naive[tuple(pair)].append(res.naive_ani) + if boxplot_covs: + pair_to_res[tuple(pair)].append((res.eff_cov)) + pair_to_res_naive[tuple(pair)].append(res.eff_cov) + + else: + pair_to_res[tuple(pair)].append((res.final_ani)) + pair_to_res_naive[tuple(pair)].append(res.naive_ani) if tuple(pair) in res_used_pairs: continue if read_file_to_status[srr] == 'Case': @@ -150,21 +155,26 @@ class result: case_res_n.append(res.naive_ani) else: - case_res.append(res.eff_cov) - case_res_n.append(res.eff_cov) + case_res.append(res.median_cov) + case_res_n.append(res.median_cov) else: if not boxplot_covs: control_res.append(res.final_ani) control_res_n.append(res.naive_ani) else: - control_res.append(res.eff_cov) - control_res_n.append(res.eff_cov) + control_res.append(res.median_cov) + control_res_n.append(res.median_cov) res_used_pairs.add(tuple(pair)) if 'globo' in res.contig_name: - pair_to_res_globo[tuple(pair)].append((res.final_ani)) - pair_to_res_naive_globo[tuple(pair)].append(res.naive_ani) + if boxplot_covs: + pair_to_res_globo[tuple(pair)].append((res.eff_cov)) + pair_to_res_naive_globo[tuple(pair)].append(res.eff_cov) + + else: + pair_to_res_globo[tuple(pair)].append((res.final_ani)) + pair_to_res_naive_globo[tuple(pair)].append(res.naive_ani) if tuple(pair) in globo_used_pairs: continue @@ -174,16 +184,16 @@ class result: case_globo_n.append(res.naive_ani) else: - case_globo.append(res.eff_cov) - case_globo_n.append(res.eff_cov) + case_globo.append(res.median_cov) + case_globo_n.append(res.median_cov) else: if not boxplot_covs: control_globo.append(res.final_ani) control_globo_n.append(res.naive_ani) else: - control_globo.append(res.eff_cov) - control_globo_n.append(res.eff_cov) + control_globo.append(res.median_cov) + control_globo_n.append(res.median_cov) globo_used_pairs.add(tuple(pair)) @@ -191,6 +201,7 @@ class result: print(len(globo_used_pairs)) #print(len(globo_used_pairs)) print(pair_to_res) +print(pair_to_res_globo) sc = np.array(list(pair_to_res.values())) sc2 = np.array(list(pair_to_res_naive.values())) @@ -229,10 +240,10 @@ class result: ax[1][0].text(.05, .99, 'M. globosa', ha='left', va='top', transform=ax[1][0].transAxes) -ax[0][0].text(.55, .25, rf"$R^2$ = {round(sr.rvalue**2,3)}", ha='left', va='top', transform=ax[0][0].transAxes) -ax[0][1].text(.55, .25, rf"$R^2$ = {round(nr.rvalue**2,3)}", ha='left', va='top', transform=ax[0][1].transAxes) -ax[1][1].text(.55, .25, rf"$R^2$ = {round(ng.rvalue**2,3)}", ha='left', va='top', transform=ax[1][1].transAxes) -ax[1][0].text(.55, .25, rf"$R^2$ = {round(sg.rvalue**2,3)}", ha='left', va='top', transform=ax[1][0].transAxes) +ax[0][0].text(.55, .25, rf"$R$ = {round(sr.rvalue**1,3)}", ha='left', va='top', transform=ax[0][0].transAxes) +ax[0][1].text(.55, .25, rf"$R$ = {round(nr.rvalue**1,3)}", ha='left', va='top', transform=ax[0][1].transAxes) +ax[1][1].text(.55, .25, rf"$R$ = {round(ng.rvalue**1,3)}", ha='left', va='top', transform=ax[1][1].transAxes) +ax[1][0].text(.55, .25, rf"$R$ = {round(sg.rvalue**1,3)}", ha='left', va='top', transform=ax[1][0].transAxes) for a in ax: for b in a: @@ -307,6 +318,17 @@ def add_stat_annotation(ax, bp, pval): control_res_n = [x[0]/2 + x[1]/2 for (a,x) in pair_to_res_naive.items() if read_file_to_status[a[0]] == 'Control'] control_globo_n = [x[0]/2 + x[1]/2 for (a,x) in pair_to_res_naive_globo.items() if read_file_to_status[a[0]] == 'Control'] +print(len([x for x in case_res if x > 95]), ': num restrica > 95 case') +print(len([x for x in control_res if x > 95]), ': num restrica > 95 cont') +print(len([x for x in case_globo if x > 95]), ': num globo > 95 case') +print(len([x for x in control_globo if x > 95]), ': num globo > 95 cont') + +print(len([x for x in case_res if x > 90]), ': num restrica > 90 case') +print(len([x for x in control_res if x > 90]), ': num restrica > 90 cont') +print(len([x for x in case_globo if x > 90]), ': num globo > 90 case') +print(len([x for x in control_globo if x > 90]), ': num globo > 90 cont') + + # Add the boxplots to the axes bp1 = ax[0][0].boxplot([case_res, control_res], patch_artist=True, boxprops=dict(facecolor=cmap[0]), labels = ['Case', 'Control']) @@ -359,6 +381,7 @@ def add_stat_annotation(ax, bp, pval): for bp in bps: for median in bp['medians']: median.set_color('black') + print(median.get_ydata()) for a in ax: for b in a: diff --git a/scripts/diagonal_ani_nn.py b/scripts/diagonal_ani_nn.py index 09822af1..c9651f9e 100644 --- a/scripts/diagonal_ani_nn.py +++ b/scripts/diagonal_ani_nn.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from natsort import natsorted cmap = sns.color_palette("muted") -plt_diag = True +plt_diag = False np.random.seed(0) def rand_jitter(arr): @@ -137,8 +137,8 @@ class result: true_results[-1].append(res) -fig, ax = plt.subplots(1, 3, figsize = (16* cm , 7 * cm), sharey = True, sharex = True) -s = 8 +fig, ax = plt.subplots(1, 3, figsize = (16* cm , 5.5 * cm), sharey = True, sharex = True) +s = 7 for (i,name) in enumerate(['Illumina', 'Nanopore-old', 'PacBio']): for j in range(1): @@ -206,7 +206,7 @@ class result: ax[i].set_xlabel("True containment ANI") if j == 1: - ax[i].scatter(x,y,s = s, color = cmap[2], alpha = 0.5, label = rf"c = 1000, $R^2$ = {round(good_lr.rvalue**2,3)}") + ax[i].scatter(x,y,s = s, color = cmap[2], alpha = 0.5, label = rf"c = 1000, $R$ = {round(good_lr.rvalue**1,3)}") else: if i == 0: ax[i].set_title("Illumina", fontsize = plt.rcParams['font.size']) @@ -215,11 +215,11 @@ class result: elif i == 2: ax[i].set_title("PacBio", fontsize = plt.rcParams['font.size']) - ax[i].scatter(x,y,s = s, color = cmap[0], alpha = 0.5, label = rf"sylph, $R^2$ = {round(good_lr.rvalue**2,3)}") - ax[i].scatter(x,z, s= s, color = cmap[3], alpha = 0.5, label = rf"Naive, $R^2$ = {round(naive_lr.rvalue**2,3)}") + ax[i].scatter(x,y,s = s, color = cmap[0], alpha = 0.5, label = rf"sylph $R$ = {round(good_lr.rvalue**1,2)}") + ax[i].scatter(x,z, s= s, color = cmap[3], alpha = 0.5, label = rf"Naive $R$ = {round(naive_lr.rvalue**1,2)}") ax[i].plot([90,100],[90,100],'--', c = 'black') print('covered: ' + str(covered/total) + 'total: ' + str(total)) - ax[i].set_ylim([85,100]) + ax[i].set_ylim([85,102]) if i == 0: ax[i].set_ylabel("Estimated ANI") diff --git a/scripts/manhat.py b/scripts/manhat.py index 0b5b2d70..fa56b15b 100644 --- a/scripts/manhat.py +++ b/scripts/manhat.py @@ -2,17 +2,28 @@ import scipy.stats as stats import statsmodels.api as sm import statsmodels.stats as ss +import plotly.express as px +import plotly.graph_objects as go from scipy.stats.distributions import norm,uniform +interactive = False import matplotlib.cm as cm import numpy as np cmap = plt.get_cmap('tab20') plt.set_cmap(cmap) q = 0.05 +LIMIT = 0.0018366409299248078 def fdr(p_vals, alpha ): not_used, pvals = ss.multitest.fdrcorrection(p_vals, alpha = alpha) + limit = 100 + for i in range(len(not_used)): + if not_used[i] == False: + if limit > p_vals[i]: + limit = p_vals[i] + continue + print("LIMIT OF DETECT", limit) return pvals @@ -21,7 +32,7 @@ def fdr(p_vals, alpha ): plt.rcParams.update({'font.size': 7}) plt.rcParams.update({'figure.autolayout': True}) plt.rcParams.update({'font.family':'arial'}) -fig, ax = plt.subplots(figsize = (6.5* cm , 6.5 * cm)) +fig, ax = plt.subplots(figsize = (5.5* cm , 4.5 * cm)) pvals = [] c = [] @@ -88,7 +99,7 @@ def fdr(p_vals, alpha ): print(gn_to_rep[s_and_mag_pair[i][1]]) #exit() -ax.scatter(-np.log10(qq[0][0]), -np.log10(qq[0][1]), s = 4, label = 'Arbitrary species representative') +ax.scatter(-np.log10(qq[0][0]), -np.log10(qq[0][1]), s = 4, label = 'Species\nrepresentative') #ax.plot([np.min(-np.log10(qq[0][0])), np.max(-np.log10(qq[0][0]))], [np.min(-np.log10(qq[0][1])), np.max(-np.log10(qq[0][1]))], 'r-') ax.plot([0,5],[0,5], 'r-') #plt.title('Q-Q plot of -log10(p-values) against uniform distribution') @@ -113,25 +124,44 @@ def fdr(p_vals, alpha ): cs = [] mag_pval_list = [] +rep_list = [] seen_cs = set() for line in open(order_file,'r'): mag = line.split('.fa')[0].rstrip() if mag in mag_to_pval: - pvals.append(mag_to_pval[mag]) - seen_cs.add(mag_to_c[mag]) - cs.append(len(seen_cs)%20) - mag_pval_list.append(mag) + if interactive: + if hash(mag) % 10 == 0: + if mag in mag_to_pval: + pvals.append(mag_to_pval[mag]) + seen_cs.add(mag_to_c[mag]) + cs.append(len(seen_cs)%20) + mag_pval_list.append(mag) + rep_list.append(gn_to_rep[mag]) + else: + pvals.append(mag_to_pval[mag]) + seen_cs.add(mag_to_c[mag]) + cs.append(len(seen_cs)%20) + mag_pval_list.append(mag) + fig, ax = plt.subplots(figsize = (16* cm , 6 * cm)) + +# Instead of creating the scatter plot using Matplotlib, use Plotly +if interactive: + fig = go.Figure() + agatho_qvals = [] agatho_rep = 'MGYG000002492' #print(pvals) qvals = fdr(np.power(10,pvals), q) -qvals = np.log10(qvals) +#I misunderstood qvals, it's a diff procedure and not benjamini hochbergq +#qvals = np.power(10,pvals) +#qvals = np.log10(qvals) +qvals = pvals seen_reps = set() for (i,pval) in enumerate(qvals): m = mag_pval_list[i] - if pval < np.log10(q): + if pval < np.log10(LIMIT): rep = gn_to_rep[m] if rep not in seen_reps: s = f"{10**pvals[i]},{10**pval},{rep},{gn_to_mag_rep[m]},{m},{mag_to_effect[mag_pval_list[i]]}" @@ -144,29 +174,59 @@ def fdr(p_vals, alpha ): size = 3 else: size = 0.3 + +if interactive: + # Add scatter plot + fig.add_trace(go.Scatter(x=list(range(len(qvals))), + y=-np.array(qvals), + mode='markers', + marker=dict(color=cs, size=2), + hovertext=[f'MAG: {rep_list[i]}' for i in range(len(qvals))], + hoverinfo="text")) + # Add a horizontal line to represent the q-value threshold + fig.add_shape( + type="line", + x0=0, + x1=len(qvals), + y0=-np.log10(LIMIT), + y1=-np.log10(LIMIT), + line=dict(color="Red", width=1) + ) + + + # Add axis labels + fig.update_layout(xaxis_title="UHGG genomes coloured and clustered by species", + yaxis_title="-log10(p-val)") + + fig.show() + exit() + plt.scatter(range(len(qvals)), -np.array(qvals),c = cs,s = size) ax.spines[['right', 'top']].set_visible(False) plt.xticks([]) -plt.axhline(-np.log10(q)) +plt.axhline(-np.log10(LIMIT)) #plt.scatter(range(len(qvals)), fd,c = c,s = 1) -plt.ylabel("-log10(q-val)") -plt.xlabel("UHGG MAGs coloured and clustered by species") +plt.ylabel("-log10(p-val)") +plt.xlabel("UHGG genomes coloured and clustered by species") plt.savefig('figures/manhat.png', dpi = 300) plt.show() cmap = plt.get_cmap('tab20') plt.set_cmap(cmap) -fig, ax = plt.subplots(figsize = (6.5* cm , 6.5 * cm)) -plt.ylabel("-log10(q-val)") -plt.xlabel("Agathobacter rectalis MAGs sorted by similarity") +fig, ax = plt.subplots(figsize = (5.5* cm , 4.5 * cm)) +plt.ylabel("-log10(p-val)") +plt.xlabel("A. rectalis genomes ordered by similarity") mag_c = mag_to_c[agatho_rep] #print(mag_c) it_tab20 = [plt.cm.tab20(i) for i in range(20)] plt.scatter(range(len(agatho_qvals)), -np.array(agatho_qvals), c = [it_tab20[mag_c] for x in agatho_qvals], s = size, cmap = cmap) plt.xticks([]) ax.spines[['right', 'top']].set_visible(False) -plt.axhline(-np.log10(q)) +plt.axhline(-np.log10(LIMIT)) plt.savefig("figures/agatho.png", dpi = 300) plt.show() + + +print(len([x for x in agatho_qvals if x < np.log10(LIMIT)]), len(agatho_qvals)) diff --git a/scripts/mock_community_plot.py b/scripts/mock_community_plot.py index 894d598a..9c8e0b9a 100644 --- a/scripts/mock_community_plot.py +++ b/scripts/mock_community_plot.py @@ -115,36 +115,41 @@ class result: res_anis = [[x.final_ani for x in res] for res in results] res_anis_low = [[x.final_ani for x in res if x.low] for res in results] res_anis_pass = [[x.adj_ani for x in res if not x.low] for res in results] +res_naive = [[x.naive_ani for x in res] for res in results] boxes = [] all_res = [] s = 7 offset = 0.0 -width = 0.6 +width = 0.7 labels = [] -num_methods = 3 +num_methods = 4 for (i,x) in enumerate(['Illumina', 'Nanopore-old', 'PacBio']): ax[i].set_title(x, fontsize = plt.rcParams['font.size']) positions = [] positions.append(i*num_methods - 1) - positions.append(i*num_methods + 0 - offset) - positions.append(i*num_methods + 1 - 2 * offset) + positions.append(i*num_methods + 0) + positions.append(i*num_methods + 1) + positions.append(i*num_methods + 2) boxes = [] box_c100 = res_anis[i] + box_naive = res_naive[i] box_mash = mash_results[i] box_sour = sour_results[i] boxes.append(box_c100) + boxes.append(box_naive) boxes.append(box_mash) boxes.append(box_sour) pos_c100_low = rand_jitter([positions[0] for x in range(len(res_anis_low[i]))]) pos_c100_pass = rand_jitter([positions[0] for x in range(len(res_anis_pass[i]))]) - pos_mash = rand_jitter([positions[1] for x in range(len(box_mash))]) - pos_sour = rand_jitter([positions[2] for x in range(len(box_sour))]) + pos_naive = rand_jitter([positions[1] for x in range(len(box_naive))]) + pos_mash = rand_jitter([positions[2] for x in range(len(box_mash))]) + pos_sour = rand_jitter([positions[3] for x in range(len(box_sour))]) dot_label = [] for box in boxes: @@ -167,6 +172,7 @@ class result: ax[i].scatter(pos_c100_pass,res_anis_pass[i], s = s, color= cmap[0]) # ax[i].set_ylim([70,100]) + ax[i].scatter( pos_naive,box_naive,s = s, color = cmap[3]) #ax[i].scatter( pos_c1000_low,res_anis_low[2*i+1], s = s, color = 'black', marker = "s") #ax[i].scatter( pos_c1000_pass, res_anis_pass[2*i+1],s = s, color = cmap[4], label=dot_label[1] + ' sylph -c1000') #ax[i].scatter( pos_c1000_pass, res_anis_pass[2*i+1],s = s, color = cmap[4]) @@ -182,8 +188,9 @@ class result: #labels.append("sylph query\n\n" + dot_label[0]) #labels.append("mash screen\n\n" + dot_label[1]) #labels.append("sourmash\n\n" + dot_label[2]) - labels.append("sylph query") - labels.append("mash screen") + labels.append("sylph") + labels.append("Naive ANI") + labels.append("Mash\nscreen") labels.append("sourmash") bp = ax[i].boxplot(boxes, showfliers=False, positions = positions, widths = width, labels=labels) diff --git a/scripts/synthetic_pois_plot.py b/scripts/synthetic_pois_plot.py index 60097c84..6a9cabc9 100644 --- a/scripts/synthetic_pois_plot.py +++ b/scripts/synthetic_pois_plot.py @@ -9,7 +9,7 @@ import glob from natsort import natsorted -cov_plot = False +cov_plot = True np.random.seed(0) def rand_jitter(arr): stdev = 0.00