Commit 649a5dc

MRG: read taxburst JSON for sankey & treemap (#87)
* WIP: read taxburst JSON for sankey & treemap
* add json loading to sankey
* note
* normalize; some debug
* fix multiple links stuff in tax_annotate output
* refactor
* add taxburst dep
* cleanup
* add taxburst examples to test workflow
* update taxburst
* comment; cleanup
* bump version
* attempt remaining part of manual merge
* try manual merge round 2
* fix Snakefile removal
1 parent 9ca2905 commit 649a5dc

File tree

5 files changed: +161 / -37 lines

README.md

Lines changed: 12 additions & 1 deletion

@@ -11,7 +11,9 @@ categories. It also includes support for sparse comparison output
 formats produced by the fast multithreaded `manysearch` and `pairwise`
 functions in the
 [branchwater plugin for sourmash](https://github.com/sourmash-bio/sourmash_plugin_branchwater).
-Finally, it includes a sankey/alluvial flow plot to visualize metagenomic profiling from the `sourmash gather` to `sourmash tax` workflow.
+Finally, it includes a sankey/alluvial flow plot and a treemap plot to
+visualize metagenomic profiling from the `sourmash gather` to
+`sourmash tax` workflow.
 
 ## Why does this plugin exist?
 

@@ -437,6 +439,10 @@ produces:
 
 By default, we will open an interactive `html` file. To output to a file, specify the file name with `-o` and use your desired filetype extension (.html, .png, .jpg, .jpeg, .pdf, or .svg). To specify the title, use `--title`.
 
+The `sankey` command also supports ingest of
+[taxburst's JSON format](https://taxburst.github.io/taxburst/command-line/#outputting-json-format),
+which allows `sankey` to be used with SingleM and Krona formats, among
+others.
 
 ### `tree` - plot Neighbor-Joining tree
 

@@ -482,6 +488,11 @@ produces:
 
 ![treemap visualization](examples/tax-mg.treemap.png)
 
+The `treemap` command also supports ingest of
+[taxburst's JSON format](https://taxburst.github.io/taxburst/command-line/#outputting-json-format),
+which allows `treemap` to be used with SingleM and Krona formats, among
+others.
+
 ### `presence_filter` - plot presence/abundance scatterplot of genomes detected by gather
 
 It is sometimes interesting to look at the distribution of size and abundance
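Both README additions point at taxburst's JSON format. The exact schema is defined by taxburst itself, but judging from how this commit's code accesses nodes (`name`, `rank`, `count`, `children`), a minimal input for the new `--taxburst-json` flags looks roughly like the sketch below. The tree contents are invented example data, not output from a real tool:

```python
import json

# A minimal taxburst-style tree: a list of top-level nodes, each carrying
# name/rank/count and a (possibly empty) list of children.
tree = [
    {"name": "Bacteria", "rank": "superkingdom", "count": 0.8,
     "children": [
         {"name": "Bacteroidota", "rank": "phylum", "count": 0.5, "children": []},
         {"name": "Pseudomonadota", "rank": "phylum", "count": 0.3, "children": []},
     ]},
    {"name": "unclassified", "rank": None, "count": 0.2, "children": []},
]

# round-trip through JSON, as the plugin would read it from a file
text = json.dumps(tree, indent=2)
top_nodes = json.loads(text)
print([n["name"] for n in top_nodes])  # ['Bacteria', 'unclassified']
```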

examples/Snakefile

Lines changed: 27 additions & 0 deletions
@@ -19,11 +19,20 @@ rule all:
         "weighted_venn.png",
         "tax-mg.sankey.png",
         "tax-annot.sankey.png",
+        "taxburst.sankey.html",
         "disttree10sketches.matrix.png",
         "disttree10sketches.pairwise.png",
         "tax-mg.treemap.png",
+        "taxburst.treemap.png",
         "presence_filter.png",
 
+rule tax:
+    input:
+        "tax-mg.sankey.png",
+        "tax-mg.treemap.png",
+        "taxburst.treemap.png",
+        "taxburst.sankey.html",
+
 rule make_10sketches:
     input:
         expand("sketches/{n}.sig.zip", n=sketches_10)

@@ -312,6 +321,15 @@ rule treemap_mgx_summary:
         sourmash scripts treemap {input} -o {output}
     """
 
+rule taxburst_treemap:
+    input:
+        "tax/SRR11125891.t0.lineages.json",
+    output:
+        "taxburst.treemap.png",
+    shell: """
+        sourmash scripts treemap --taxburst-json {input} -o {output}
+    """
+
 rule sankey_mgx_annotate:
     input:
         "tax/test.gather.with-lineages.csv"

@@ -322,6 +340,15 @@ rule sankey_mgx_annotate:
         sourmash scripts sankey --annotate-csv {input} -o {output}
     """
 
+rule sankey_taxburst:
+    input:
+        "tax/SRR11125891.t0.lineages.json",
+    output:
+        "taxburst.sankey.html",
+    shell: """
+        sourmash scripts sankey --taxburst-json {input} -o {output}
+    """
+
 rule tree_10sketches_compare_matrix:
     input:
         cmp="10sketches.cmp",

examples/tax/SRR11125891.t0.lineages.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -3,14 +3,14 @@ name = "sourmash_plugin_betterplot"
 description = "sourmash plugin for improved plotting/viz and cluster examination."
 readme = "README.md"
 requires-python = ">=3.11"
-version = "0.5.5"
+version = "0.5.6"
 
 # note: "legacy_cgi" is currently needed for ete3, but may need to be changed on next ete release, see: https://github.com/etetoolkit/ete/issues/780
 dependencies = ["sourmash>=4.9.4,<5", "sourmash_utils>=0.2",
                 "matplotlib", "numpy", "scipy", "scikit-learn",
                 "seaborn", "upsetplot", "matplotlib_venn", "pandas",
                 "plotly", "biopython", "ete3", "kaleido", "pyqt5",
-                "legacy_cgi", "squarify==0.4.4"]
+                "legacy_cgi", "squarify==0.4.4", "taxburst>=0.3.1"]
 
 [build-system]
 requires = ["setuptools>=61.0"]

src/sourmash_plugin_betterplot.py

Lines changed: 119 additions & 34 deletions
@@ -13,6 +13,7 @@
 from collections import defaultdict, Counter
 from itertools import chain, combinations
 import pickle
+import json
 
 import numpy
 import pylab

@@ -28,6 +29,7 @@
 import pandas as pd
 import plotly.graph_objects as go
 import squarify
+import taxburst
 
 # this turns off a warning in presence_filter, but results in an error in
 # upsetplot :sweat_smile:

@@ -38,7 +40,7 @@
 
 import sourmash
 from sourmash import sourmash_args
-from sourmash.logging import debug_literal, error, notify
+from sourmash.logging import debug_literal, error, notify, print_results
 from sourmash.plugins import CommandLinePlugin
 import sourmash_utils
 from sourmash.cli.utils import (add_ksize_arg, add_moltype_args, add_scaled_arg)

@@ -211,6 +213,21 @@ def sample_d_to_idents(sample_d):
 
     return xx
 
+def load_taxburst_json(filename, *, normalize_counts=True):
+    """
+    Load in JSON format output by taxburst.
+
+    Optionally normalize counts to fractions that sum to 1.
+    """
+    with open(filename, 'rt') as fp:
+        top_nodes = json.load(fp)
+
+    if normalize_counts:
+        taxburst.tree_utils.normalize_tree_counts(top_nodes)
+
+    return top_nodes
+
+
 #
 # CLI plugin code
 #
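The new `load_taxburst_json` helper delegates normalization to `taxburst.tree_utils.normalize_tree_counts`. A minimal dependency-free sketch of what such normalization plausibly does — divide every count by the top-level total so counts become fractions summing to 1 — assuming taxburst-style nodes with `name`/`rank`/`count`/`children` keys. The `normalize_counts` helper here is illustrative, not taxburst's actual implementation:

```python
def normalize_counts(top_nodes):
    # Divide every node's count by the total of the top-level counts,
    # so that the top-level counts sum to 1 (fractions of the sample).
    total = sum(n["count"] for n in top_nodes)

    def walk(nodes):
        for n in nodes:
            n["count"] = n["count"] / total
            walk(n.get("children", []))

    walk(top_nodes)
    return top_nodes

# invented example data in taxburst-like shape
tree = [
    {"name": "Bacteria", "rank": "superkingdom", "count": 75,
     "children": [{"name": "Pseudomonadota", "rank": "phylum",
                   "count": 75, "children": []}]},
    {"name": "unclassified", "rank": None, "count": 25, "children": []},
]
normalize_counts(tree)
print(tree[0]["count"])  # 0.75
```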
@@ -1765,6 +1782,7 @@ def save_sankey_diagram(fig, output_file):
     else:
         fig.show()  # Show the plot if no output file is specified
 
+
 def load_lingroups(map_csv):
     """Return {full_lineage_string: human_name}."""
     lin2name = {}

@@ -1774,6 +1792,7 @@ def load_lingroups(map_csv):
     notify(f"loaded {len(lin2name)} lingroup names from '{map_csv}'")
     return lin2name
 
+
 def expand_with_ancestors_sum(rows, fraction_col):
     """Expand rows with all ancestor paths, summing children if ancestor missing."""
     lineage_fracs = {row["lineage"].strip(): float(row[fraction_col])

@@ -2038,6 +2057,47 @@ def process_csv_for_sankey(input_csv, csv_type, lingroup_map=None):
 
     return nodes, links, hover_texts
 
+
+def process_taxburst_for_sankey(input_file):
+    nodes = []                  # List of unique taxonomy nodes
+    node_map = {}               # Map taxonomic label to index
+    links = []                  # List of link connections with flow values
+    hover_texts = []            # Custom hover text for percentages
+    processed_lineages = set()  # Tracks added lineage links
+
+    top_nodes = load_taxburst_json(input_file)
+    all_nodes = taxburst.tree_utils.collect_all_nodes(top_nodes)
+
+    # Process each row in the dataset
+    for n, node in enumerate(all_nodes):
+        source_label = node["name"]
+
+        if source_label not in node_map:
+            node_map[source_label] = len(nodes)
+            nodes.append(source_label)
+
+        # Iterate through children
+        for child_node in node.get("children", []):
+            percent = float(child_node["count"]) * 100
+            target_label = child_node["name"]
+
+            # Assign indices to nodes
+            if target_label not in node_map:
+                node_map[target_label] = len(nodes)
+                nodes.append(target_label)
+
+            # Create a link between source and target
+            links.append({
+                "source": node_map[source_label],
+                "target": node_map[target_label],
+                "value": percent
+            })
+            hover_texts.append(f"{source_label} → {target_label}<br>{percent:.2f}%")
+    notify(f"loaded {n+1} nodes from '{input_file}'")
+
+    return nodes, links, hover_texts
+
+
 class Command_Sankey(CommandLinePlugin):
     command = 'sankey'
     description = """\
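The new `process_taxburst_for_sankey` builds plotly-style node/link lists from the taxburst tree. The same flattening can be sketched without the taxburst dependency — `tree_to_sankey` below is a hypothetical stand-in that recurses over the tree directly instead of calling `taxburst.tree_utils.collect_all_nodes`, and omits hover-text generation:

```python
def tree_to_sankey(top_nodes):
    # Flatten parent→child edges into sankey-style node/link lists:
    # each unique label gets an index; each edge carries the child's
    # normalized count, scaled to a percentage, as the link value.
    nodes, node_map, links = [], {}, []

    def index_of(label):
        if label not in node_map:
            node_map[label] = len(nodes)
            nodes.append(label)
        return node_map[label]

    def walk(tree_nodes):
        for node in tree_nodes:
            src = index_of(node["name"])
            for child in node.get("children", []):
                links.append({"source": src,
                              "target": index_of(child["name"]),
                              "value": float(child["count"]) * 100})
            walk(node.get("children", []))

    walk(top_nodes)
    return nodes, links

# invented example data, counts already normalized to fractions
tree = [{"name": "Bacteria", "count": 1.0,
         "children": [{"name": "Bacillota", "count": 0.6, "children": []},
                      {"name": "Bacteroidota", "count": 0.4, "children": []}]}]
nodes, links = tree_to_sankey(tree)
print(nodes)  # ['Bacteria', 'Bacillota', 'Bacteroidota']
```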
@@ -2056,6 +2116,7 @@ def __init__(self, subparser):
         group = subparser.add_mutually_exclusive_group(required=True)
         group.add_argument("--summary-csv", type=str, help="Path to csv_summary generated by running 'sourmash tax metagenome' on a sourmash gather csv")
         group.add_argument("--annotate-csv", type=str, help="Path to 'with-lineages' file generated by running 'sourmash tax annotate' on a sourmash gather csv")
+        group.add_argument('--taxburst-json', type=str, help="taxburst JSON output")
         subparser.add_argument("--lingroups", type=str, help="Path to 'lingroups' file (lineage to lingroup mapping) to enable lingroup labeling in the Sankey diagram. Not needed if `csv_summary` was generated with `--lingroups` file provided.")
 
         subparser.add_argument("-o", "--output", type=str, help="output file for alluvial flow diagram")

@@ -2064,25 +2125,31 @@ def __init__(self, subparser):
         subparser.epilog = "You must provide either --summary-csv or --annotate-csv, but not both."
 
     def main(self, args):
-        # Build info appropriately based on input file type
-        if args.summary_csv:
-            input_csv = args.summary_csv
-            csv_type = "csv_summary"
-            required_headers = ["f_weighted_at_rank", "lineage"]
+        if args.summary_csv or args.annotate_csv:
+            # Build info appropriately based on input file type
+            if args.summary_csv:
+                input_csv = args.summary_csv
+                csv_type = "csv_summary"
+                required_headers = ["f_weighted_at_rank", "lineage"]
+            else:
+                input_csv = args.annotate_csv
+                csv_type = "with-lineages"
+                required_headers = ["f_unique_weighted", "lineage"]
+
+            # Check if the required headers are present
+            with open(input_csv, 'r') as file:
+                reader = csv.DictReader(file)
+                if not all(header in reader.fieldnames for header in required_headers):
+                    raise ValueError(f"Expected headers {required_headers} not found. Is this a correct file for '{csv_type}' type?")
+
+            # process csv
+            nodes, links, hover_texts = process_csv_for_sankey(input_csv, csv_type, lingroup_map=args.lingroups)
+            base_title = os.path.basename(input_csv.rsplit(".csv")[0])
+        elif args.taxburst_json:
+            nodes, links, hover_texts = process_taxburst_for_sankey(args.taxburst_json)
+            base_title = os.path.basename(args.taxburst_json.rsplit(".json")[0])
         else:
-            input_csv = args.annotate_csv
-            csv_type = "with-lineages"
-            required_headers = ["f_unique_weighted", "lineage"]
-
-        # Check if the required headers are present
-        with open(input_csv, 'r') as file:
-            reader = csv.DictReader(file)
-            if not all(header in reader.fieldnames for header in required_headers):
-                raise ValueError(f"Expected headers {required_headers} not found. Is this a correct file for '{csv_type}' type?")
-
-        # process csv
-        nodes, links, hover_texts = process_csv_for_sankey(input_csv, csv_type, lingroup_map=args.lingroups)
-        base_title = os.path.basename(input_csv.rsplit(".csv")[0])
+            assert 0, "unhandled input format"
 
         # Create Sankey diagram
         fig = go.Figure(go.Sankey(
@@ -2283,14 +2350,17 @@ class Command_TreeMap(CommandLinePlugin):
 
     def __init__(self, subparser):
         super().__init__(subparser)
-        subparser.add_argument('csvfile', help='csv_summary output from tax metagenome')
+        subparser.add_argument('inputfile', help='input taxonomy - by default, csv_summary output from tax metagenome')
         subparser.add_argument('-o', '--output', required=True,
                                help='output figure to this file')
         subparser.add_argument('-r', '--rank', default='phylum',
                                help='display at this rank')
         subparser.add_argument('-n', '--num-to-display', type=int,
                                default=25,
                                help="display at most these many taxa; aggregate the remainder (default: 25; 0 to display all)")
+        subparser.add_argument('--taxburst-json',
+                               action='store_true',
+                               help='input format is JSON from taxburst')
 
 
     def main(self, args):

@@ -2303,24 +2373,39 @@ def plot_treemap(args):
     import itertools
     cmap = colormaps['viridis']
 
-    df = pd.read_csv(args.csvfile)
+    if not args.taxburst_json:
+        df = pd.read_csv(args.inputfile)
 
-    print(f"reading input file '{args.csvfile}'")
-    for colname in ('query_name', 'rank', 'f_weighted_at_rank', 'lineage'):
-        if colname not in df.columns:
-            print(f"input is missing column '{colname}'; is this a csv_summary file?")
-            sys.exit(-1)
+        print(f"reading input file '{args.inputfile}'")
+        for colname in ('query_name', 'rank', 'f_weighted_at_rank', 'lineage'):
+            if colname not in df.columns:
+                print(f"input is missing column '{colname}'; is this a csv_summary file?")
+                sys.exit(-1)
 
-    df = df.sort_values(by='f_weighted_at_rank')
+        df = df.sort_values(by='f_weighted_at_rank')
 
-    # select rank
-    df2 = df[df['rank'] == args.rank]
-    df2['name'] = df2['lineage'].apply(lambda x: x.split(';')[-1])
+        # select rank
+        df2 = df[df['rank'] == args.rank]
+        df2['name'] = df2['lineage'].apply(lambda x: x.split(';')[-1])
 
-    fractions = list(df2['f_weighted_at_rank'].tolist())
-    names = list(df2['name'].tolist())
-    fractions.reverse()
-    names.reverse()
+        fractions = list(df2['f_weighted_at_rank'].tolist())
+        names = list(df2['name'].tolist())
+        fractions.reverse()
+        names.reverse()
+    else:
+        assert args.taxburst_json
+        top_nodes = load_taxburst_json(args.inputfile)
+
+        all_nodes = taxburst.tree_utils.collect_all_nodes(top_nodes)
+        all_nodes = [ n for n in all_nodes if n["rank"] == args.rank ]
+        unclass = [ n for n in top_nodes if n["name"] == "unclassified" ]
+        if unclass:
+            assert len(unclass) == 1
+            all_nodes.append(unclass[0])
+
+        all_nodes.sort(key=lambda n: -n["count"])
+        fractions = [ n["count"] for n in all_nodes ]
+        names = [ n["name"] for n in all_nodes ]
 
     num = max(args.num_to_display, 0)  # non-negative
     num = min(args.num_to_display, len(names))  # list of names
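The clamping of `num_to_display` above feeds an aggregation step that falls outside this diff. As a purely hypothetical illustration of the documented behavior — "display at most these many taxa; aggregate the remainder" with 0 meaning "display all" — assuming `fractions` holds the normalized counts computed earlier:

```python
def top_n_with_other(names, fractions, num):
    # Keep the `num` largest entries and lump the rest into "other".
    # num <= 0 means "display all" (per the --num-to-display help text).
    pairs = sorted(zip(names, fractions), key=lambda p: -p[1])
    if num <= 0 or num >= len(pairs):
        return pairs
    kept = pairs[:num]
    remainder = sum(f for _, f in pairs[num:])
    return kept + [("other", remainder)]

# invented example data
names = ["Bacillota", "Bacteroidota", "Pseudomonadota", "Actinomycetota"]
fracs = [0.4, 0.3, 0.2, 0.1]
result = top_n_with_other(names, fracs, 2)
print([n for n, _ in result])  # ['Bacillota', 'Bacteroidota', 'other']
```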
