Skip to content

Commit

Permalink
sqlite: revert report order change
Browse files (browse the repository at this point in the history)
  • Loading branch information
victorlin committed Jun 23, 2023
1 parent 31aa9dd commit 5a9e0cc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion augur/filter/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def run(args: argparse.Namespace):

write_outputs(args)

print_report(args)
print_report(args, exclude_by, include_by)

# TODO: The current implementation assumes the database file is hidden from
# the user. If this ever changes, clean the database of any
Expand Down
6 changes: 3 additions & 3 deletions augur/filter/include_exclude_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def apply_exclusions(exclude_by: List[FilterOption]):

sql_parameters = {
'filter_reason': exclude_function.__name__,
'filter_reason_kwargs': _filter_kwargs_to_str(kwargs)
'filter_reason_kwargs': filter_kwargs_to_str(kwargs)
}

# Add parameters returned from the filter function.
Expand Down Expand Up @@ -674,7 +674,7 @@ def apply_force_inclusions(include_by: List[FilterOption]):

sql_parameters = {
'filter_reason': include_function.__name__,
'filter_reason_kwargs': _filter_kwargs_to_str(kwargs)
'filter_reason_kwargs': filter_kwargs_to_str(kwargs)
}

# Add parameters returned from the filter function.
Expand All @@ -684,7 +684,7 @@ def apply_force_inclusions(include_by: List[FilterOption]):
db.connection.execute(sql, sql_parameters)


def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs):
def filter_kwargs_to_str(kwargs: FilterFunctionKwargs):
"""Convert a dictionary of kwargs to a JSON string for downstream reporting.
This structured string can be converted back into a Python data structure
Expand Down
32 changes: 15 additions & 17 deletions augur/filter/report.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
from typing import List, Tuple
from typing import Dict, List
from augur.errors import AugurError
from augur.io.print import print_err
from augur.io.sqlite3 import Sqlite3Database
from augur.types import EmptyOutputReportingMethod
from . import constants, include_exclude_rules


def print_report(args):
def print_report(args, exclude_by, include_by):
"""Print a report of how many strains were dropped and reasoning."""
num_excluded_by_lack_of_metadata = _get_num_excluded_by_lack_of_metadata()
num_metadata_strains = _get_num_metadata_strains()
Expand Down Expand Up @@ -41,15 +41,20 @@ def print_report(args):
include_exclude_rules.force_include_where.__name__: "{count} sequences were force-included because of '{include_where}'",
}

all_filters: List[include_exclude_rules.FilterOption] = exclude_by + include_by

filter_counts = _get_filter_counts()
for filter_name, filter_kwargs, count in filter_counts:
for filter_function, filter_kwargs in all_filters:
if filter_function.__name__ not in filter_counts:
continue

if filter_kwargs:
parameters = dict(json.loads(filter_kwargs))
parameters = dict(json.loads(include_exclude_rules.filter_kwargs_to_str(filter_kwargs)))
else:
parameters = {}

parameters["count"] = count
print("\t" + report_template_by_filter_name[filter_name].format(**parameters))
parameters["count"] = filter_counts[filter_function.__name__]
print("\t" + report_template_by_filter_name[filter_function.__name__].format(**parameters))

if (args.group_by and args.sequences_per_group) or args.subsample_max_sequences:
seed_txt = ", using seed {}".format(args.subsample_seed) if args.subsample_seed else ""
Expand Down Expand Up @@ -111,26 +116,19 @@ def _get_total_strains_passed() -> int:
return result.fetchone()[0]


def _get_filter_counts() -> List[Tuple[str, str, int]]:
"""Returns a list of tuples for each filter function that had an effect.
Each tuple has:
1. Name of the filter function
2. Arguments given to the filter function
3. Number of strains included/excluded by the filter function
Sort for reproducible output.
def _get_filter_counts() -> Dict[str, int]:
"""
Returns a mapping of filter function name to the number of strains included/excluded by it.
"""
with Sqlite3Database(constants.RUNTIME_DB_FILE) as db:
result = db.connection.execute(f"""
SELECT
{constants.FILTER_REASON_COLUMN},
{constants.FILTER_REASON_KWARGS_COLUMN},
COUNT(*)
FROM {constants.FILTER_REASON_TABLE}
WHERE {constants.FILTER_REASON_COLUMN} IS NOT NULL
AND {constants.FILTER_REASON_COLUMN} != '{constants.SUBSAMPLE_FILTER_REASON}'
GROUP BY {constants.FILTER_REASON_COLUMN}, {constants.FILTER_REASON_KWARGS_COLUMN}
ORDER BY {constants.FILTER_REASON_COLUMN}
""")
return result.fetchall()
return {row[0]: row[1] for row in result.fetchall()}

0 comments on commit 5a9e0cc

Please sign in to comment.