Merge pull request #1454: Implement weighted sampling
victorlin authored Aug 21, 2024
2 parents 242d67f + 9a2090b commit 1e9d131
Showing 16 changed files with 876 additions and 46 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -7,12 +7,14 @@
* A new command, `augur merge`, now allows for generalized merging of two or more metadata tables. [#1563][] (@tsibley)
* Two new commands, `augur read-file` and `augur write-file`, now allow external programs to do i/o like Augur by piping from/to these new commands. They provide handling of compression formats and newlines consistent with the rest of Augur. [#1562][] (@tsibley)
* A new debugging mode can be enabled by setting the `AUGUR_DEBUG` environment variable to `1` (or any non-empty value). Currently the only effect is to print more information about handled (i.e. anticipated) errors. For example, stack traces and parent exceptions in an exception chain are normally omitted for handled errors, but setting this env var includes them. Future debugging and troubleshooting features, like verbose operation logging, will likely also condition on this new debugging mode. [#1577][] (@tsibley)
* filter: Added the ability to use weights in subsampling. See help text of `--group-by-weights` for more information. [#1454][] (@victorlin)

### Bug Fixes

* Embedded newlines in quoted field values of metadata files read/written by many commands, annotation files read by `augur curate apply-record-annotations`, and index files written by `augur index` are now properly handled. [#1561][] [#1564][] (@tsibley)
* Output written to stderr (e.g. informational messages, warnings, errors, etc.) is now always line-buffered regardless of the Python version in use. This helps with interleaved stderr and stdout. Previously, stderr was block-buffered on Python 3.8 and line-buffered on 3.9 and higher. [#1563][] (@tsibley)
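The line-buffering fix above can be illustrated with a small sketch (this is not augur's implementation, just a demonstration of the buffering behavior it relies on): with `line_buffering=True`, every newline-terminated write reaches the underlying stream immediately, which is what keeps interleaved stderr and stdout output ordered.

```python
import io

# Line-buffered: a newline in the written text triggers an immediate flush.
raw = io.BytesIO()
line_buffered = io.TextIOWrapper(raw, line_buffering=True)
line_buffered.write("reading metadata...\n")
assert raw.getvalue() == b"reading metadata...\n"  # already flushed

# Block-buffered (the old Python 3.8 stderr behavior): the same short write
# sits in the text buffer until an explicit flush or close.
raw2 = io.BytesIO()
block_buffered = io.TextIOWrapper(raw2, line_buffering=False)
block_buffered.write("reading metadata...\n")
assert raw2.getvalue() == b""  # not yet visible to the reader
```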

[#1454]: https://github.com/nextstrain/augur/pull/1454
[#1561]: https://github.com/nextstrain/augur/pull/1561
[#1562]: https://github.com/nextstrain/augur/pull/1562
[#1563]: https://github.com/nextstrain/augur/pull/1563
8 changes: 6 additions & 2 deletions augur/dates/__init__.py
@@ -143,5 +143,9 @@ def get_numerical_dates(metadata:pd.DataFrame, name_col = None, date_col='date',
dates = metadata[date_col].astype(float)
return dict(zip(strains, dates))

def get_iso_year_week(year, month, day):
return datetime.date(year, month, day).isocalendar()[:2]
def get_year_month(year, month):
return f"{year}-{str(month).zfill(2)}"

def get_year_week(year, month, day):
year, week = datetime.date(year, month, day).isocalendar()[:2]
return f"{year}-{str(week).zfill(2)}"
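The two helpers added here are small enough to exercise directly. Note that `get_year_week` uses the ISO calendar, so near year boundaries the ISO year can differ from the calendar year — the reason the function returns the year from `isocalendar()` rather than echoing its `year` argument:

```python
import datetime

def get_year_month(year, month):
    return f"{year}-{str(month).zfill(2)}"

def get_year_week(year, month, day):
    year, week = datetime.date(year, month, day).isocalendar()[:2]
    return f"{year}-{str(week).zfill(2)}"

print(get_year_month(2024, 8))    # 2024-08
print(get_year_week(2024, 8, 21))
# 2021-01-01 is a Friday in ISO week 53 of ISO year 2020:
print(get_year_week(2021, 1, 1))  # 2020-53
```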
32 changes: 29 additions & 3 deletions augur/filter/__init__.py
@@ -67,9 +67,34 @@ def register_arguments(parser):
subsample_limits_group = subsample_group.add_mutually_exclusive_group()
subsample_limits_group.add_argument('--sequences-per-group', type=int, help="subsample to no more than this number of sequences per category")
subsample_limits_group.add_argument('--subsample-max-sequences', type=int, help="subsample to no more than this number of sequences; can be used without the group_by argument")
group_size_options = subsample_group.add_mutually_exclusive_group()
group_size_options.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided.")
group_size_options.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
probabilistic_sampling_group = subsample_group.add_mutually_exclusive_group()
probabilistic_sampling_group.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided.")
probabilistic_sampling_group.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
subsample_group.add_argument('--group-by-weights', type=str, metavar="FILE", help="""
TSV file defining weights for grouping. Requirements:
(1) Lines starting with '#' are treated as comment lines.
(2) The first non-comment line must be a header row.
(3) There must be a numeric ``weight`` column (weights can take on any
non-negative values).
(4) Other columns must be a subset of columns used in ``--group-by``,
with combinations of values covering all combinations present in the
metadata.
(5) This option only applies when ``--group-by`` and
``--subsample-max-sequences`` are provided.
(6) This option cannot be used with ``--no-probabilistic-sampling``.
Notes:
(1) Any ``--group-by`` columns absent from this file will be given equal
weighting across all values *within* groups defined by the other
weighted columns.
(2) An entry with the value ``default`` under all columns will be
treated as the default weight for specific groups present in the
metadata but missing from the weights file. If there is no default
weight and the metadata contains rows that are not covered by the
given weights, augur filter will exit with an error.
""")
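As a concrete illustration of the requirements in the help text above (all values hypothetical), a weights file for `--group-by region year` that only weights `region` might look like:

```tsv
# Hypothetical weights file for: --group-by region year
region	weight
Africa	2
Asia	2
Europe	1
default	1
```

Here `year` is omitted, so per note (1) all years within each region receive equal weighting, and per note (2) the `default` row covers regions present in the metadata but absent from this file.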
subsample_group.add_argument('--priority', type=str, help="""tab-delimited file with list of priority scores for strains (e.g., "<strain>\\t<priority>") and no header.
When scores are provided, Augur converts scores to floating point values, sorts strains within each subsampling group from highest to lowest priority, and selects the top N strains per group where N is the calculated or requested number of strains per group.
Higher numbers indicate higher priority.
@@ -81,6 +106,7 @@ def register_arguments(parser):
output_group.add_argument('--output-metadata', help="metadata for strains that passed filters")
output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)")
output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.")
output_group.add_argument('--output-group-by-sizes', help="tab-delimited file with one row per group and its target size.")
output_group.add_argument(
'--empty-output-reporting',
type=EmptyOutputReportingMethod.argtype,
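The core idea behind `get_weighted_group_sizes` — dividing the overall target among groups in proportion to their weights — can be sketched as follows. This is a hypothetical simplification for illustration (the function name and largest-remainder rounding are assumptions; the real implementation also caps sizes by the sequences actually available per group and handles fractional targets probabilistically):

```python
import math

def weighted_group_sizes(group_weights, target_total):
    """Allocate per-group targets proportional to weights (sketch only)."""
    total_weight = sum(group_weights.values())
    raw = {g: target_total * w / total_weight for g, w in group_weights.items()}
    # Floor every share, then hand leftover slots to the groups with the
    # largest fractional parts so the sizes sum exactly to the target.
    sizes = {g: math.floor(r) for g, r in raw.items()}
    leftover = target_total - sum(sizes.values())
    for g in sorted(raw, key=lambda g: raw[g] - sizes[g], reverse=True)[:leftover]:
        sizes[g] += 1
    return sizes

print(weighted_group_sizes({"Africa": 2, "Asia": 2, "Europe": 1}, 100))
```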
55 changes: 33 additions & 22 deletions augur/filter/_run.py
@@ -23,7 +23,7 @@
from . import include_exclude_rules
from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
from .include_exclude_rules import apply_filters, construct_filters
from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling
from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling, get_weighted_group_sizes


def run(args):
@@ -264,32 +264,43 @@ def run(args):
# group. Then, we need to make a second pass through the metadata to find
# the requested number of records.
if args.subsample_max_sequences and records_per_group is not None:
# Calculate sequences per group. If there are more groups than maximum
# sequences requested, sequences per group will be a floating point
# value and subsampling will be probabilistic.
try:
sequences_per_group, probabilistic_used = calculate_sequences_per_group(
args.subsample_max_sequences,
records_per_group.values(),
args.probabilistic_sampling,
)
except TooManyGroupsError as error:
raise AugurError(error)

if queues_by_group is None:
# We know all of the possible groups now from the first pass through
# the metadata, so we can create queues for all groups at once.
if (probabilistic_used):
print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
group_sizes = get_probabilistic_group_sizes(
records_per_group.keys(),
sequences_per_group,
random_seed=args.subsample_seed,
if args.group_by_weights:
print_err(f"Sampling with weights defined by {args.group_by_weights}.")
group_sizes = get_weighted_group_sizes(
records_per_group,
group_by,
args.group_by_weights,
args.subsample_max_sequences,
args.output_group_by_sizes,
args.subsample_seed,
)
else:
print_err(f"Sampling at {sequences_per_group} per group.")
assert type(sequences_per_group) is int
group_sizes = {group: sequences_per_group for group in records_per_group.keys()}
# Calculate sequences per group. If there are more groups than maximum
# sequences requested, sequences per group will be a floating point
# value and subsampling will be probabilistic.
try:
sequences_per_group, probabilistic_used = calculate_sequences_per_group(
args.subsample_max_sequences,
records_per_group.values(),
args.probabilistic_sampling,
)
except TooManyGroupsError as error:
raise AugurError(error)

if (probabilistic_used):
print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
group_sizes = get_probabilistic_group_sizes(
records_per_group.keys(),
sequences_per_group,
random_seed=args.subsample_seed,
)
else:
print_err(f"Sampling at {sequences_per_group} per group.")
assert type(sequences_per_group) is int
group_sizes = {group: sequences_per_group for group in records_per_group.keys()}
queues_by_group = create_queues_by_group(group_sizes)

# Make a second pass through the metadata, only considering records that
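The unweighted path retained above calls `calculate_sequences_per_group`, whose behavior the comments describe: pick a whole-number per-group cap when possible, fall back to a fractional (probabilistic) cap when there are more groups than requested sequences. A hypothetical simplification of that logic, not augur's exact implementation:

```python
def calculate_sequences_per_group(target_max, counts, probabilistic=True):
    """Return (cap, probabilistic_used) for the given per-group record counts."""
    def total(cap):
        # A group can contribute at most as many sequences as it contains.
        return sum(min(cap, count) for count in counts)

    if len(counts) > target_max:
        # Even one sequence per group would overshoot the target.
        if not probabilistic:
            raise ValueError("more groups than requested sequences")
        return target_max / len(counts), True

    # Largest whole cap whose total stays within the target.
    cap = 1
    while total(cap + 1) <= target_max and total(cap + 1) > total(cap):
        cap += 1
    return cap, False
```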

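When the cap is fractional, `get_probabilistic_group_sizes` (seeded by `--subsample-seed`) converts it into whole per-group sizes. A sketch of the idea — each group draws its size from a Poisson distribution whose mean is the fractional cap, so the expected total matches the requested maximum. The Poisson sampler here (Knuth's algorithm) is a stand-in assumption, not augur's actual code:

```python
import math
import random

def get_probabilistic_group_sizes(groups, fractional_cap, random_seed=None):
    """Draw a whole target size per group with mean `fractional_cap` (sketch)."""
    rng = random.Random(random_seed)

    def poisson(lam):
        # Knuth's algorithm; adequate for the small means used here.
        threshold = math.exp(-lam)
        k, p = 0, 1.0
        while True:
            p *= rng.random()
            if p <= threshold:
                return k
            k += 1

    return {group: poisson(fractional_cap) for group in groups}
```

With many groups, the sum of the drawn sizes concentrates around `fractional_cap * len(groups)`, which is why probabilistic sampling can slightly exceed the requested maximum, as the warning printed above notes.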