@@ -13,6 +13,7 @@
 from collections import defaultdict, Counter
 from itertools import chain, combinations
 import pickle
+import json

 import numpy
 import pylab
@@ -28,6 +29,7 @@
 import pandas as pd
 import plotly.graph_objects as go
 import squarify
+import taxburst

 # this turns off a warning in presence_filter, but results in an error in
 # upsetplot :sweat_smile:
@@ -38,7 +40,7 @@

 import sourmash
 from sourmash import sourmash_args
-from sourmash.logging import debug_literal, error, notify
+from sourmash.logging import debug_literal, error, notify, print_results
 from sourmash.plugins import CommandLinePlugin
 import sourmash_utils
 from sourmash.cli.utils import (add_ksize_arg, add_moltype_args, add_scaled_arg)
@@ -211,6 +213,21 @@ def sample_d_to_idents(sample_d):

     return xx

+def load_taxburst_json(filename, *, normalize_counts=True):
+    """
+    Load JSON-format output produced by taxburst.
+
+    Optionally normalize counts to fractions that sum to 1.
+    """
+    with open(filename, 'rt') as fp:
+        top_nodes = json.load(fp)
+
+    if normalize_counts:
+        taxburst.tree_utils.normalize_tree_counts(top_nodes)
+
+    return top_nodes
+
+
 #
 # CLI plugin code
 #
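For orientation, here is a minimal sketch of the node structure `load_taxburst_json` assumes, inferred from how the rest of this diff accesses nodes (`name`, `rank`, `count`, `children`); the taxa and values are illustrative, not taken from taxburst's documentation:

```python
# Hypothetical taxburst output: a list of nested node dicts (illustrative
# values only). After normalize_tree_counts(), "count" holds fractions
# that sum to 1 across the top-level nodes.
example_top_nodes = [
    {"name": "Bacteria", "rank": "superkingdom", "count": 0.95,
     "children": [
         {"name": "Proteobacteria", "rank": "phylum", "count": 0.60,
          "children": []},
     ]},
    {"name": "unclassified", "rank": "", "count": 0.05, "children": []},
]
```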
@@ -1765,6 +1782,7 @@ def save_sankey_diagram(fig, output_file):
     else:
         fig.show()  # Show the plot if no output file is specified

+
 def load_lingroups(map_csv):
     """Return {full_lineage_string: human_name}."""
     lin2name = {}
@@ -1774,6 +1792,7 @@ def load_lingroups(map_csv):
     notify(f"loaded {len(lin2name)} lingroup names from '{map_csv}'")
     return lin2name

+
 def expand_with_ancestors_sum(rows, fraction_col):
     """Expand rows with all ancestor paths, summing children if ancestor missing."""
     lineage_fracs = {row["lineage"].strip(): float(row[fraction_col])
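The body of `expand_with_ancestors_sum` is truncated in this hunk; the following is a hedged sketch of the behavior its docstring describes, not code from this PR:

```python
# Illustrative only: assumed behavior per the docstring, since the function
# body is truncated here. The column name matches the csv_summary format.
rows = [
    {"lineage": "d__Bacteria;p__Proteobacteria", "f_weighted_at_rank": "0.4"},
    {"lineage": "d__Bacteria;p__Firmicutes", "f_weighted_at_rank": "0.3"},
]
# "d__Bacteria" is absent above, so expansion should credit the ancestor
# with the sum of its children's fractions: 0.4 + 0.3 = 0.7.
```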
@@ -2038,6 +2057,47 @@ def process_csv_for_sankey(input_csv, csv_type, lingroup_map=None):

     return nodes, links, hover_texts

+
+def process_taxburst_for_sankey(input_file):
+    nodes = []  # List of unique taxonomy nodes
+    node_map = {}  # Map taxonomic label to index
+    links = []  # List of link connections with flow values
+    hover_texts = []  # Custom hover text for percentages
+    processed_lineages = set()  # Tracks added lineage links (currently unused)
+
+    top_nodes = load_taxburst_json(input_file)
+    all_nodes = taxburst.tree_utils.collect_all_nodes(top_nodes)
+
+    # Process each node in the taxonomy tree
+    for n, node in enumerate(all_nodes):
+        source_label = node["name"]
+
+        if source_label not in node_map:
+            node_map[source_label] = len(nodes)
+            nodes.append(source_label)
+
+        # Iterate through children
+        for child_node in node.get("children", []):
+            percent = float(child_node["count"]) * 100
+            target_label = child_node["name"]
+
+            # Assign indices to nodes
+            if target_label not in node_map:
+                node_map[target_label] = len(nodes)
+                nodes.append(target_label)
+
+            # Create a link between source and target
+            links.append({
+                "source": node_map[source_label],
+                "target": node_map[target_label],
+                "value": percent
+            })
+            hover_texts.append(f"{source_label} → {target_label}<br>{percent:.2f}%")
+    notify(f"loaded {n + 1} nodes from '{input_file}'")
+
+    return nodes, links, hover_texts
+
+
 class Command_Sankey(CommandLinePlugin):
     command = 'sankey'
     description = """\
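As a usage sketch (assumed, not part of this PR), the returned `nodes`/`links`/`hover_texts` plug straight into plotly's Sankey trace, mirroring how `Command_Sankey.main` builds its figure below; the input filename is hypothetical:

```python
import plotly.graph_objects as go

nodes, links, hover_texts = process_taxburst_for_sankey("sample.taxburst.json")

fig = go.Figure(go.Sankey(
    node=dict(label=nodes),
    link=dict(
        source=[link["source"] for link in links],
        target=[link["target"] for link in links],
        value=[link["value"] for link in links],
        customdata=hover_texts,
        hovertemplate="%{customdata}<extra></extra>",
    ),
))
fig.show()
```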
@@ -2056,6 +2116,7 @@ def __init__(self, subparser):
         group = subparser.add_mutually_exclusive_group(required=True)
         group.add_argument("--summary-csv", type=str, help="Path to csv_summary generated by running 'sourmash tax metagenome' on a sourmash gather csv")
         group.add_argument("--annotate-csv", type=str, help="Path to 'with-lineages' file generated by running 'sourmash tax annotate' on a sourmash gather csv")
+        group.add_argument('--taxburst-json', type=str, help="taxburst JSON output")
         subparser.add_argument("--lingroups", type=str, help="Path to 'lingroups' file (lineage to lingroup mapping) to enable lingroup labeling in the Sankey diagram. Not needed if `csv_summary` was generated with `--lingroups` file provided.")

         subparser.add_argument("-o", "--output", type=str, help="output file for alluvial flow diagram")
@@ -2064,25 +2125,31 @@ def __init__(self, subparser):
         subparser.epilog = "You must provide either --summary-csv or --annotate-csv, but not both."

     def main(self, args):
-        # Build info appropriately based on input file type
-        if args.summary_csv:
-            input_csv = args.summary_csv
-            csv_type = "csv_summary"
-            required_headers = ["f_weighted_at_rank", "lineage"]
+        if args.summary_csv or args.annotate_csv:
+            # Build info appropriately based on input file type
+            if args.summary_csv:
+                input_csv = args.summary_csv
+                csv_type = "csv_summary"
+                required_headers = ["f_weighted_at_rank", "lineage"]
+            else:
+                input_csv = args.annotate_csv
+                csv_type = "with-lineages"
+                required_headers = ["f_unique_weighted", "lineage"]
+
+            # Check if the required headers are present
+            with open(input_csv, 'r') as file:
+                reader = csv.DictReader(file)
+                if not all(header in reader.fieldnames for header in required_headers):
+                    raise ValueError(f"Expected headers {required_headers} not found. Is this a correct file for '{csv_type}' type?")
+
+            # process csv
+            nodes, links, hover_texts = process_csv_for_sankey(input_csv, csv_type, lingroup_map=args.lingroups)
+            base_title = os.path.basename(input_csv.rsplit(".csv")[0])
+        elif args.taxburst_json:
+            nodes, links, hover_texts = process_taxburst_for_sankey(args.taxburst_json)
+            base_title = os.path.basename(args.taxburst_json.rsplit(".json")[0])
         else:
-            input_csv = args.annotate_csv
-            csv_type = "with-lineages"
-            required_headers = ["f_unique_weighted", "lineage"]
-
-        # Check if the required headers are present
-        with open(input_csv, 'r') as file:
-            reader = csv.DictReader(file)
-            if not all(header in reader.fieldnames for header in required_headers):
-                raise ValueError(f"Expected headers {required_headers} not found. Is this a correct file for '{csv_type}' type?")
-
-        # process csv
-        nodes, links, hover_texts = process_csv_for_sankey(input_csv, csv_type, lingroup_map=args.lingroups)
-        base_title = os.path.basename(input_csv.rsplit(".csv")[0])
+            assert 0, "unhandled input format"

         # Create Sankey diagram
         fig = go.Figure(go.Sankey(
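With this dispatch in place, a hypothetical invocation of the new input path would be `sourmash scripts sankey --taxburst-json sample.taxburst.json -o sample.sankey.html` (filenames are placeholders; sourmash exposes `CommandLinePlugin` commands under `sourmash scripts`).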
@@ -2283,14 +2350,17 @@ class Command_TreeMap(CommandLinePlugin):

     def __init__(self, subparser):
         super().__init__(subparser)
-        subparser.add_argument('csvfile', help='csv_summary output from tax metagenome')
+        subparser.add_argument('inputfile', help='input taxonomy - by default, csv_summary output from tax metagenome')
         subparser.add_argument('-o', '--output', required=True,
                                help='output figure to this file')
         subparser.add_argument('-r', '--rank', default='phylum',
                                help='display at this rank')
         subparser.add_argument('-n', '--num-to-display', type=int,
                                default=25,
                                help="display at most these many taxa; aggregate the remainder (default: 25; 0 to display all)")
+        subparser.add_argument('--taxburst-json',
+                               action='store_true',
+                               help='input format is JSON from taxburst')


     def main(self, args):
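The command name registered by `Command_TreeMap` is not visible in this diff; using `treemap` purely as a placeholder, the new flag would be invoked as `sourmash scripts treemap sample.taxburst.json --taxburst-json -r phylum -o treemap.png` (all filenames hypothetical).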
@@ -2303,24 +2373,39 @@ def plot_treemap(args):
    import itertools
    cmap = colormaps['viridis']

-    df = pd.read_csv(args.csvfile)
+    if not args.taxburst_json:
+        df = pd.read_csv(args.inputfile)

-    print(f"reading input file '{args.csvfile}'")
-    for colname in ('query_name', 'rank', 'f_weighted_at_rank', 'lineage'):
-        if colname not in df.columns:
-            print(f"input is missing column '{colname}'; is this a csv_summary file?")
-            sys.exit(-1)
+        print(f"reading input file '{args.inputfile}'")
+        for colname in ('query_name', 'rank', 'f_weighted_at_rank', 'lineage'):
+            if colname not in df.columns:
+                print(f"input is missing column '{colname}'; is this a csv_summary file?")
+                sys.exit(-1)

-    df = df.sort_values(by='f_weighted_at_rank')
+        df = df.sort_values(by='f_weighted_at_rank')

-    # select rank
-    df2 = df[df['rank'] == args.rank]
-    df2['name'] = df2['lineage'].apply(lambda x: x.split(';')[-1])
+        # select rank
+        df2 = df[df['rank'] == args.rank]
+        df2['name'] = df2['lineage'].apply(lambda x: x.split(';')[-1])

-    fractions = list(df2['f_weighted_at_rank'].tolist())
-    names = list(df2['name'].tolist())
-    fractions.reverse()
-    names.reverse()
+        fractions = list(df2['f_weighted_at_rank'].tolist())
+        names = list(df2['name'].tolist())
+        fractions.reverse()
+        names.reverse()
+    else:
+        assert args.taxburst_json
+        top_nodes = load_taxburst_json(args.inputfile)
+
+        all_nodes = taxburst.tree_utils.collect_all_nodes(top_nodes)
+        all_nodes = [n for n in all_nodes if n["rank"] == args.rank]
+        unclass = [n for n in top_nodes if n["name"] == "unclassified"]
+        if unclass:
+            assert len(unclass) == 1
+            all_nodes.append(unclass[0])
+
+        all_nodes.sort(key=lambda n: -n["count"])
+        fractions = [n["count"] for n in all_nodes]
+        names = [n["name"] for n in all_nodes]

     num = max(args.num_to_display, 0)  # non-negative
     num = min(args.num_to_display, len(names))  # list of names
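A minimal sketch (assumed, not from this PR) of how the `fractions`/`names` computed above can feed a squarify treemap; `plot_treemap`'s actual rendering code sits below the visible hunk, and the values and output path here are illustrative:

```python
import matplotlib.pyplot as plt
import squarify

fractions = [0.6, 0.25, 0.15]                     # illustrative values
names = ["Proteobacteria", "Firmicutes", "other"]

# squarify lays out rectangles proportional to sizes; labels are drawn
# by matplotlib on the current axes.
squarify.plot(sizes=fractions, label=names, alpha=0.8)
plt.axis('off')
plt.savefig("treemap.png")                        # hypothetical output path
```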