Skip to content

Commit

Permalink
script add
Browse files Browse the repository at this point in the history
  • Loading branch information
bluenote-1577 committed Nov 12, 2023
1 parent 1693b6f commit d6524e6
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 47 deletions.
55 changes: 39 additions & 16 deletions scripts/chng_between_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,13 @@ class result:
continue
pair = natsorted([srr,read_file_to_pair_dict[srr]])
if 'restrict' in res.contig_name:
pair_to_res[tuple(pair)].append((res.final_ani))
pair_to_res_naive[tuple(pair)].append(res.naive_ani)
if boxplot_covs:
pair_to_res[tuple(pair)].append((res.eff_cov))
pair_to_res_naive[tuple(pair)].append(res.eff_cov)

else:
pair_to_res[tuple(pair)].append((res.final_ani))
pair_to_res_naive[tuple(pair)].append(res.naive_ani)
if tuple(pair) in res_used_pairs:
continue
if read_file_to_status[srr] == 'Case':
Expand All @@ -150,21 +155,26 @@ class result:
case_res_n.append(res.naive_ani)

else:
case_res.append(res.eff_cov)
case_res_n.append(res.eff_cov)
case_res.append(res.median_cov)
case_res_n.append(res.median_cov)
else:
if not boxplot_covs:
control_res.append(res.final_ani)
control_res_n.append(res.naive_ani)

else:
control_res.append(res.eff_cov)
control_res_n.append(res.eff_cov)
control_res.append(res.median_cov)
control_res_n.append(res.median_cov)

res_used_pairs.add(tuple(pair))
if 'globo' in res.contig_name:
pair_to_res_globo[tuple(pair)].append((res.final_ani))
pair_to_res_naive_globo[tuple(pair)].append(res.naive_ani)
if boxplot_covs:
pair_to_res_globo[tuple(pair)].append((res.eff_cov))
pair_to_res_naive_globo[tuple(pair)].append(res.eff_cov)

else:
pair_to_res_globo[tuple(pair)].append((res.final_ani))
pair_to_res_naive_globo[tuple(pair)].append(res.naive_ani)

if tuple(pair) in globo_used_pairs:
continue
Expand All @@ -174,23 +184,24 @@ class result:
case_globo_n.append(res.naive_ani)

else:
case_globo.append(res.eff_cov)
case_globo_n.append(res.eff_cov)
case_globo.append(res.median_cov)
case_globo_n.append(res.median_cov)

else:
if not boxplot_covs:
control_globo.append(res.final_ani)
control_globo_n.append(res.naive_ani)
else:
control_globo.append(res.eff_cov)
control_globo_n.append(res.eff_cov)
control_globo.append(res.median_cov)
control_globo_n.append(res.median_cov)

globo_used_pairs.add(tuple(pair))

print(case_res)
print(len(globo_used_pairs))
#print(len(globo_used_pairs))
print(pair_to_res)
print(pair_to_res_globo)

sc = np.array(list(pair_to_res.values()))
sc2 = np.array(list(pair_to_res_naive.values()))
Expand Down Expand Up @@ -229,10 +240,10 @@ class result:
ax[1][0].text(.05, .99, 'M. globosa', ha='left', va='top', transform=ax[1][0].transAxes)


ax[0][0].text(.55, .25, rf"$R^2$ = {round(sr.rvalue**2,3)}", ha='left', va='top', transform=ax[0][0].transAxes)
ax[0][1].text(.55, .25, rf"$R^2$ = {round(nr.rvalue**2,3)}", ha='left', va='top', transform=ax[0][1].transAxes)
ax[1][1].text(.55, .25, rf"$R^2$ = {round(ng.rvalue**2,3)}", ha='left', va='top', transform=ax[1][1].transAxes)
ax[1][0].text(.55, .25, rf"$R^2$ = {round(sg.rvalue**2,3)}", ha='left', va='top', transform=ax[1][0].transAxes)
ax[0][0].text(.55, .25, rf"$R$ = {round(sr.rvalue**1,3)}", ha='left', va='top', transform=ax[0][0].transAxes)
ax[0][1].text(.55, .25, rf"$R$ = {round(nr.rvalue**1,3)}", ha='left', va='top', transform=ax[0][1].transAxes)
ax[1][1].text(.55, .25, rf"$R$ = {round(ng.rvalue**1,3)}", ha='left', va='top', transform=ax[1][1].transAxes)
ax[1][0].text(.55, .25, rf"$R$ = {round(sg.rvalue**1,3)}", ha='left', va='top', transform=ax[1][0].transAxes)

for a in ax:
for b in a:
Expand Down Expand Up @@ -307,6 +318,17 @@ def add_stat_annotation(ax, bp, pval):
control_res_n = [x[0]/2 + x[1]/2 for (a,x) in pair_to_res_naive.items() if read_file_to_status[a[0]] == 'Control']
control_globo_n = [x[0]/2 + x[1]/2 for (a,x) in pair_to_res_naive_globo.items() if read_file_to_status[a[0]] == 'Control']

print(len([x for x in case_res if x > 95]), ': num restrica > 95 case')
print(len([x for x in control_res if x > 95]), ': num restrica > 95 cont')
print(len([x for x in case_globo if x > 95]), ': num globo > 95 case')
print(len([x for x in control_globo if x > 95]), ': num globo > 95 cont')

print(len([x for x in case_res if x > 90]), ': num restrica > 90 case')
print(len([x for x in control_res if x > 90]), ': num restrica > 90 cont')
print(len([x for x in case_globo if x > 90]), ': num globo > 90 case')
print(len([x for x in control_globo if x > 90]), ': num globo > 90 cont')



# Add the boxplots to the axes
bp1 = ax[0][0].boxplot([case_res, control_res], patch_artist=True, boxprops=dict(facecolor=cmap[0]), labels = ['Case', 'Control'])
Expand Down Expand Up @@ -359,6 +381,7 @@ def add_stat_annotation(ax, bp, pval):
for bp in bps:
for median in bp['medians']:
median.set_color('black')
print(median.get_ydata())

for a in ax:
for b in a:
Expand Down
14 changes: 7 additions & 7 deletions scripts/diagonal_ani_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass
from natsort import natsorted
cmap = sns.color_palette("muted")
plt_diag = True
plt_diag = False

np.random.seed(0)
def rand_jitter(arr):
Expand Down Expand Up @@ -137,8 +137,8 @@ class result:
true_results[-1].append(res)


fig, ax = plt.subplots(1, 3, figsize = (16* cm , 7 * cm), sharey = True, sharex = True)
s = 8
fig, ax = plt.subplots(1, 3, figsize = (16* cm , 5.5 * cm), sharey = True, sharex = True)
s = 7

for (i,name) in enumerate(['Illumina', 'Nanopore-old', 'PacBio']):
for j in range(1):
Expand Down Expand Up @@ -206,7 +206,7 @@ class result:

ax[i].set_xlabel("True containment ANI")
if j == 1:
ax[i].scatter(x,y,s = s, color = cmap[2], alpha = 0.5, label = rf"c = 1000, $R^2$ = {round(good_lr.rvalue**2,3)}")
ax[i].scatter(x,y,s = s, color = cmap[2], alpha = 0.5, label = rf"c = 1000, $R$ = {round(good_lr.rvalue**1,3)}")
else:
if i == 0:
ax[i].set_title("Illumina", fontsize = plt.rcParams['font.size'])
Expand All @@ -215,11 +215,11 @@ class result:
elif i == 2:
ax[i].set_title("PacBio", fontsize = plt.rcParams['font.size'])

ax[i].scatter(x,y,s = s, color = cmap[0], alpha = 0.5, label = rf"sylph, $R^2$ = {round(good_lr.rvalue**2,3)}")
ax[i].scatter(x,z, s= s, color = cmap[3], alpha = 0.5, label = rf"Naive, $R^2$ = {round(naive_lr.rvalue**2,3)}")
ax[i].scatter(x,y,s = s, color = cmap[0], alpha = 0.5, label = rf"sylph $R$ = {round(good_lr.rvalue**1,2)}")
ax[i].scatter(x,z, s= s, color = cmap[3], alpha = 0.5, label = rf"Naive $R$ = {round(naive_lr.rvalue**1,2)}")
ax[i].plot([90,100],[90,100],'--', c = 'black')
print('covered: ' + str(covered/total) + 'total: ' + str(total))
ax[i].set_ylim([85,100])
ax[i].set_ylim([85,102])
if i == 0:
ax[i].set_ylabel("Estimated ANI")

Expand Down
90 changes: 75 additions & 15 deletions scripts/manhat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,28 @@
import scipy.stats as stats
import statsmodels.api as sm
import statsmodels.stats as ss
import plotly.express as px
import plotly.graph_objects as go

from scipy.stats.distributions import norm,uniform
interactive = False

import matplotlib.cm as cm
import numpy as np
cmap = plt.get_cmap('tab20')
plt.set_cmap(cmap)
q = 0.05
LIMIT = 0.0018366409299248078

def fdr(p_vals, alpha ):
    """Run FDR correction on ``p_vals`` and return the corrected p-values.

    As a diagnostic, also prints the smallest raw p-value whose hypothesis
    was NOT rejected at level ``alpha`` (labelled "LIMIT OF DETECT").

    Args:
        p_vals: Raw p-values, forwarded unchanged to
            ``ss.multitest.fdrcorrection``.
        alpha: Forwarded as the ``alpha`` keyword of ``fdrcorrection``.

    Returns:
        The second value returned by ``ss.multitest.fdrcorrection``
        (the corrected p-values, per the statsmodels documentation).
    """
    # Per the statsmodels documentation the first return value is a boolean
    # mask that is True where the hypothesis is rejected.
    rejected, corrected = ss.multitest.fdrcorrection(p_vals, alpha = alpha)

    # 100 is a sentinel larger than any valid p-value; it is what gets
    # printed when every hypothesis is rejected.
    limit = 100
    for was_rejected, raw_p in zip(rejected, p_vals):
        if not was_rejected and raw_p < limit:
            limit = raw_p

    print("LIMIT OF DETECT", limit)
    return corrected


Expand All @@ -21,7 +32,7 @@ def fdr(p_vals, alpha ):
plt.rcParams.update({'font.size': 7})
plt.rcParams.update({'figure.autolayout': True})
plt.rcParams.update({'font.family':'arial'})
fig, ax = plt.subplots(figsize = (6.5* cm , 6.5 * cm))
fig, ax = plt.subplots(figsize = (5.5* cm , 4.5 * cm))

pvals = []
c = []
Expand Down Expand Up @@ -88,7 +99,7 @@ def fdr(p_vals, alpha ):
print(gn_to_rep[s_and_mag_pair[i][1]])
#exit()

ax.scatter(-np.log10(qq[0][0]), -np.log10(qq[0][1]), s = 4, label = 'Arbitrary species representative')
ax.scatter(-np.log10(qq[0][0]), -np.log10(qq[0][1]), s = 4, label = 'Species\nrepresentative')
#ax.plot([np.min(-np.log10(qq[0][0])), np.max(-np.log10(qq[0][0]))], [np.min(-np.log10(qq[0][1])), np.max(-np.log10(qq[0][1]))], 'r-')
ax.plot([0,5],[0,5], 'r-')
#plt.title('Q-Q plot of -log10(p-values) against uniform distribution')
Expand All @@ -113,25 +124,44 @@ def fdr(p_vals, alpha ):

cs = []
mag_pval_list = []
rep_list = []
seen_cs = set()
for line in open(order_file,'r'):
mag = line.split('.fa')[0].rstrip()
if mag in mag_to_pval:
pvals.append(mag_to_pval[mag])
seen_cs.add(mag_to_c[mag])
cs.append(len(seen_cs)%20)
mag_pval_list.append(mag)
if interactive:
if hash(mag) % 10 == 0:
if mag in mag_to_pval:
pvals.append(mag_to_pval[mag])
seen_cs.add(mag_to_c[mag])
cs.append(len(seen_cs)%20)
mag_pval_list.append(mag)
rep_list.append(gn_to_rep[mag])
else:
pvals.append(mag_to_pval[mag])
seen_cs.add(mag_to_c[mag])
cs.append(len(seen_cs)%20)
mag_pval_list.append(mag)


fig, ax = plt.subplots(figsize = (16* cm , 6 * cm))

# Instead of creating the scatter plot using Matplotlib, use Plotly
if interactive:
fig = go.Figure()

agatho_qvals = []
agatho_rep = 'MGYG000002492'
#print(pvals)
qvals = fdr(np.power(10,pvals), q)
qvals = np.log10(qvals)
# I misunderstood q-values: they come from a different procedure, not Benjamini-Hochberg.
#qvals = np.power(10,pvals)
#qvals = np.log10(qvals)
qvals = pvals
seen_reps = set()
for (i,pval) in enumerate(qvals):
m = mag_pval_list[i]
if pval < np.log10(q):
if pval < np.log10(LIMIT):
rep = gn_to_rep[m]
if rep not in seen_reps:
s = f"{10**pvals[i]},{10**pval},{rep},{gn_to_mag_rep[m]},{m},{mag_to_effect[mag_pval_list[i]]}"
Expand All @@ -144,29 +174,59 @@ def fdr(p_vals, alpha ):
size = 3
else:
size = 0.3

if interactive:
# Add scatter plot
fig.add_trace(go.Scatter(x=list(range(len(qvals))),
y=-np.array(qvals),
mode='markers',
marker=dict(color=cs, size=2),
hovertext=[f'MAG: {rep_list[i]}' for i in range(len(qvals))],
hoverinfo="text"))
# Add a horizontal line to represent the q-value threshold
fig.add_shape(
type="line",
x0=0,
x1=len(qvals),
y0=-np.log10(LIMIT),
y1=-np.log10(LIMIT),
line=dict(color="Red", width=1)
)


# Add axis labels
fig.update_layout(xaxis_title="UHGG genomes coloured and clustered by species",
yaxis_title="-log10(p-val)")

fig.show()
exit()

plt.scatter(range(len(qvals)), -np.array(qvals),c = cs,s = size)
ax.spines[['right', 'top']].set_visible(False)
plt.xticks([])
plt.axhline(-np.log10(q))
plt.axhline(-np.log10(LIMIT))

#plt.scatter(range(len(qvals)), fd,c = c,s = 1)
plt.ylabel("-log10(q-val)")
plt.xlabel("UHGG MAGs coloured and clustered by species")
plt.ylabel("-log10(p-val)")
plt.xlabel("UHGG genomes coloured and clustered by species")
plt.savefig('figures/manhat.png', dpi = 300)
plt.show()

cmap = plt.get_cmap('tab20')
plt.set_cmap(cmap)

fig, ax = plt.subplots(figsize = (6.5* cm , 6.5 * cm))
plt.ylabel("-log10(q-val)")
plt.xlabel("Agathobacter rectalis MAGs sorted by similarity")
fig, ax = plt.subplots(figsize = (5.5* cm , 4.5 * cm))
plt.ylabel("-log10(p-val)")
plt.xlabel("A. rectalis genomes ordered by similarity")
mag_c = mag_to_c[agatho_rep]
#print(mag_c)
it_tab20 = [plt.cm.tab20(i) for i in range(20)]
plt.scatter(range(len(agatho_qvals)), -np.array(agatho_qvals), c = [it_tab20[mag_c] for x in agatho_qvals], s = size, cmap = cmap)
plt.xticks([])
ax.spines[['right', 'top']].set_visible(False)
plt.axhline(-np.log10(q))
plt.axhline(-np.log10(LIMIT))
plt.savefig("figures/agatho.png", dpi = 300)
plt.show()


print(len([x for x in agatho_qvals if x < np.log10(LIMIT)]), len(agatho_qvals))
23 changes: 15 additions & 8 deletions scripts/mock_community_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,36 +115,41 @@ class result:
res_anis = [[x.final_ani for x in res] for res in results]
res_anis_low = [[x.final_ani for x in res if x.low] for res in results]
res_anis_pass = [[x.adj_ani for x in res if not x.low] for res in results]
res_naive = [[x.naive_ani for x in res] for res in results]
boxes = []
all_res = []
s = 7

offset = 0.0
width = 0.6
width = 0.7
labels = []
num_methods = 3
num_methods = 4


for (i,x) in enumerate(['Illumina', 'Nanopore-old', 'PacBio']):
ax[i].set_title(x, fontsize = plt.rcParams['font.size'])
positions = []
positions.append(i*num_methods - 1)
positions.append(i*num_methods + 0 - offset)
positions.append(i*num_methods + 1 - 2 * offset)
positions.append(i*num_methods + 0)
positions.append(i*num_methods + 1)
positions.append(i*num_methods + 2)

boxes = []
box_c100 = res_anis[i]
box_naive = res_naive[i]
box_mash = mash_results[i]
box_sour = sour_results[i]

boxes.append(box_c100)
boxes.append(box_naive)
boxes.append(box_mash)
boxes.append(box_sour)

pos_c100_low = rand_jitter([positions[0] for x in range(len(res_anis_low[i]))])
pos_c100_pass = rand_jitter([positions[0] for x in range(len(res_anis_pass[i]))])
pos_mash = rand_jitter([positions[1] for x in range(len(box_mash))])
pos_sour = rand_jitter([positions[2] for x in range(len(box_sour))])
pos_naive = rand_jitter([positions[1] for x in range(len(box_naive))])
pos_mash = rand_jitter([positions[2] for x in range(len(box_mash))])
pos_sour = rand_jitter([positions[3] for x in range(len(box_sour))])

dot_label = []
for box in boxes:
Expand All @@ -167,6 +172,7 @@ class result:
ax[i].scatter(pos_c100_pass,res_anis_pass[i], s = s, color= cmap[0])
# ax[i].set_ylim([70,100])

ax[i].scatter( pos_naive,box_naive,s = s, color = cmap[3])
#ax[i].scatter( pos_c1000_low,res_anis_low[2*i+1], s = s, color = 'black', marker = "s")
#ax[i].scatter( pos_c1000_pass, res_anis_pass[2*i+1],s = s, color = cmap[4], label=dot_label[1] + ' sylph -c1000')
#ax[i].scatter( pos_c1000_pass, res_anis_pass[2*i+1],s = s, color = cmap[4])
Expand All @@ -182,8 +188,9 @@ class result:
#labels.append("sylph query\n\n" + dot_label[0])
#labels.append("mash screen\n\n" + dot_label[1])
#labels.append("sourmash\n\n" + dot_label[2])
labels.append("sylph query")
labels.append("mash screen")
labels.append("sylph")
labels.append("Naive ANI")
labels.append("Mash\nscreen")
labels.append("sourmash")

bp = ax[i].boxplot(boxes, showfliers=False, positions = positions, widths = width, labels=labels)
Expand Down
Loading

0 comments on commit d6524e6

Please sign in to comment.