diff --git a/augur/dates/__init__.py b/augur/dates/__init__.py
index 046638826..0c44168c0 100644
--- a/augur/dates/__init__.py
+++ b/augur/dates/__init__.py
@@ -126,9 +126,9 @@ def get_numerical_date_from_value(value, fmt=None, min_max_year=None):
     except:
         return None
 
-def get_numerical_dates(metadata:pd.DataFrame, name_col = None, date_col='date', fmt=None, min_max_year=None):
-    if not isinstance(metadata, pd.DataFrame):
-        raise AugurError("Metadata should be a pandas.DataFrame.")
+def get_numerical_dates(metadata, name_col = None, date_col='date', fmt=None, min_max_year=None):
+    # if not isinstance(metadata, pd.DataFrame):
+    #     raise AugurError("Metadata should be a pandas.DataFrame.")
     if fmt:
         strains = metadata.index.values
         dates = metadata[date_col].apply(
@@ -136,12 +136,12 @@ def get_numerical_dates(metadata:pd.DataFrame, name_col = None, date_col='date',
                 date,
                 fmt,
                 min_max_year
-            )
+            ), meta=(date_col, 'str')
         ).values
     else:
         strains = metadata.index.values
         dates = metadata[date_col].astype(float)
-    return dict(zip(strains, dates))
+    return dict(zip(strains.compute(), dates.compute()))
 
 def get_iso_year_week(year, month, day):
     return datetime.date(year, month, day).isocalendar()[:2]
diff --git a/augur/filter/_run.py b/augur/filter/_run.py
index e1ee97659..aa685e5e7 100644
--- a/augur/filter/_run.py
+++ b/augur/filter/_run.py
@@ -4,6 +4,7 @@
 import json
 import os
 import pandas as pd
+import dask.dataframe as dd
 from tempfile import NamedTemporaryFile
 
 from augur.errors import AugurError
@@ -14,7 +15,7 @@
     DELIMITER as SEQUENCE_INDEX_DELIMITER,
 )
 from augur.io.file import PANDAS_READ_CSV_OPTIONS, open_file
-from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata
+from augur.io.metadata import InvalidDelimiter, Metadata
 from augur.io.sequences import read_sequences, write_sequences
 from augur.io.print import print_err
 from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf
@@ -97,20 +98,17 @@ def run(args):
         )
 
     useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns)
 
-    metadata = read_metadata(
+    metadata = dd.read_csv(
         args.metadata,
-        delimiters=[metadata_object.delimiter],
-        columns=useful_metadata_columns,
-        id_columns=[metadata_object.id_column],
-        dtype={col: 'category' for col in useful_metadata_columns},
-    )
-
-    duplicate_strains = metadata.index[metadata.index.duplicated()]
-    if len(duplicate_strains) > 0:
-        raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains)))
+        delimiter=metadata_object.delimiter,
+        usecols=useful_metadata_columns,
+        dtype='str',
+    ).set_index(metadata_object.id_column)
 
-    # FIXME: remove redundant variable from chunking logic
-    metadata_strains = set(metadata.index.values)
+    # FIXME: detect duplicates
+    # duplicate_strains = metadata.index[metadata.index.duplicated()]
+    # if len(duplicate_strains) > 0:
+    #     raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains)))
 
     # Setup filters.
     exclude_by, include_by = construct_filters(
@@ -261,16 +259,17 @@ def run(args):
            args.metadata_id_columns, args.output_metadata, args.output_strains, valid_strains)
 
+    # FIXME: inspect metadata/sequence mismatch
     # Calculate the number of strains that don't exist in either metadata or
     # sequences.
     num_excluded_by_lack_of_metadata = 0
-    if sequence_strains:
-        num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)
+    # if sequence_strains:
+    #     num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)
 
     # Calculate the number of strains passed and filtered.
     total_strains_passed = len(valid_strains)
-    total_strains_filtered = len(metadata_strains) + num_excluded_by_lack_of_metadata - total_strains_passed
+    total_strains_filtered = len(metadata.index) + num_excluded_by_lack_of_metadata - total_strains_passed
 
     print(f"{total_strains_filtered} {'strain was' if total_strains_filtered == 1 else 'strains were'} dropped during filtering")
diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py
index a8de9aba7..47eb3eeb4 100644
--- a/augur/filter/include_exclude_rules.py
+++ b/augur/filter/include_exclude_rules.py
@@ -341,7 +341,7 @@ def filter_by_min_date(metadata, date_column, min_date) -> FilterFunctionReturn:
     ['strain1', 'strain2']
     """
-    strains = set(metadata.index.values)
+    strains = set(metadata.index.values.compute())
 
     # Skip this filter if the date column does not exist.
     if date_column not in metadata.columns:
@@ -766,7 +766,7 @@ def apply_filters(metadata, exclude_by: List[FilterOption], include_by: List[Fil
     [{'strain': 'strain2', 'filter': 'force_include_where', 'kwargs': '[["include_where", "region=Europe"]]'}]
     """
-    strains_to_keep = set(metadata.index.values)
+    strains_to_keep = set(metadata.index.values.compute())
     strains_to_filter = []
     strains_to_force_include = []
     distinct_strains_to_force_include: Set = set()
diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py
index 3b9bfc651..fad890e23 100644
--- a/augur/filter/subsample.py
+++ b/augur/filter/subsample.py
@@ -27,19 +27,22 @@ def subsample(metadata, args, group_by):
         metadata,
         group_by,
     )
+
+    def apply_priorities(row):
+        return priorities[row.name]
 
     # Enrich with priorities.
-    grouping_metadata['priority'] = [priorities[strain] for strain in grouping_metadata.index]
+    grouping_metadata['priority'] = grouping_metadata.apply(apply_priorities, axis=1, meta=('priority', 'f8'))
 
     pandas_groupby = grouping_metadata.groupby(list(group_by), group_keys=False)
-    n_groups = len(pandas_groupby.groups)
+    n_groups = len(pandas_groupby.size())
 
     # Determine sequences per group.
     if args.sequences_per_group:
         sequences_per_group = args.sequences_per_group
     elif args.subsample_max_sequences:
-        group_sizes = [len(strains) for strains in pandas_groupby.groups.values()]
+        group_sizes = pandas_groupby.size().compute().tolist()
 
         try:
             # Calculate sequences per group. If there are more groups than maximum
diff --git a/setup.py b/setup.py
index efc46cd06..931bf81b2 100644
--- a/setup.py
+++ b/setup.py
@@ -55,6 +55,8 @@
         # TODO: Remove biopython >= 1.80 pin if it is added to bcbio-gff: https://github.com/chapmanb/bcbb/issues/142
         "biopython >=1.80, ==1.*",
         "cvxopt >=1.1.9, ==1.*",
+        "dask[dataframe]",
+        "pyarrow",
         "importlib_resources >=5.3.0; python_version < '3.11'",
         "isodate ==0.6.*",
         "jsonschema >=3.0.0, ==3.*",
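
A note on the `# FIXME: detect duplicates` hunk in `augur/filter/_run.py`: the original pandas check used `index.duplicated()`, which doesn't translate directly because a dask index is lazy and spread across partitions. One possible replacement, shown below as a minimal sketch rather than part of this patch, is to count index values lazily and materialize only the per-strain counts. It assumes the `metadata` dask DataFrame, `args`, and the `AugurError` import from the surrounding code.

# Sketch only: duplicate detection on a lazy dask index. dask's Index
# inherits Series methods, so value_counts() is available and its computed
# result is small (one row per distinct strain).
counts = metadata.index.value_counts().compute()
duplicate_strains = sorted(counts[counts > 1].index)
if duplicate_strains:
    raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(duplicate_strains))

This keeps memory proportional to the number of distinct strains rather than the full metadata, at the cost of one extra pass over the input file.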
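
Similarly, the `# FIXME: inspect metadata/sequence mismatch` hunk comments out the accounting that relied on `metadata_strains`. That count only requires the strain names, so one option (again a sketch under the same assumptions, with `sequence_strains` possibly empty as in the original pandas code) is to compute just the index:

# Sketch only: restore the metadata/sequence accounting by materializing
# the index (strain names) rather than the whole DataFrame.
num_excluded_by_lack_of_metadata = 0
metadata_strains = set(metadata.index.compute())
if sequence_strains:
    num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)

If adopted, `len(metadata_strains)` could also replace the later `len(metadata.index)` call when computing `total_strains_filtered`, avoiding a second pass over the metadata.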