From ee1a62b7f21c4cc2d0351cdb0177847dcf45f179 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 7 Jul 2023 13:51:22 -0700 Subject: [PATCH] [wip] Add timing output to --debug for long-running functions --- augur/filter/dates.py | 2 ++ augur/filter/debug.py | 20 ++++++++++++++++++++ augur/filter/io.py | 3 +++ augur/filter/subsample.py | 2 ++ 4 files changed, 27 insertions(+) create mode 100644 augur/filter/debug.py diff --git a/augur/filter/dates.py b/augur/filter/dates.py index 53268330c..9b25a6174 100644 --- a/augur/filter/dates.py +++ b/augur/filter/dates.py @@ -3,11 +3,13 @@ from augur.dates import get_numerical_date_from_value from augur.dates.errors import InvalidDate from augur.errors import AugurError +from augur.filter.debug import add_debugging from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier from . import constants +@add_debugging def parse_dates(): """Validate dates and create a date table.""" # First, determine if there is a date column. diff --git a/augur/filter/debug.py b/augur/filter/debug.py new file mode 100644 index 000000000..a31c16479 --- /dev/null +++ b/augur/filter/debug.py @@ -0,0 +1,20 @@ +from functools import wraps +import time +from typing import Callable + +from . import constants + + +def add_debugging(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + if constants.RUNTIME_DEBUG: + start_time = time.perf_counter() + result = func(*args, **kwargs) + end_time = time.perf_counter() + total_time = end_time - start_time + print(f'Function {func.__name__} took {total_time:.4f} seconds') + return result + else: + return func(*args, **kwargs) + return wrapper diff --git a/augur/filter/io.py b/augur/filter/io.py index 7b1af05fa..57e55cce1 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -2,6 +2,7 @@ from typing import Iterable, Sequence, Set from tempfile import NamedTemporaryFile from augur.errors import AugurError +from augur.filter.debug import add_debugging from augur.index import ( index_sequences, index_vcf, @@ -79,6 +80,7 @@ def _import_tabular_file(file: TabularFile, db: Sqlite3Database, table: str): db.insert(table, file.columns, file.rows()) +@add_debugging def import_metadata(path: str, id_columns: Sequence[str], delimiters: Iterable[str]): # Initialize metadata object. try: @@ -246,6 +248,7 @@ def _create_output_table(): """) +@add_debugging def _read_and_output_sequences(args): """Read sequences and output all that passed filtering. """ diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index 55cbb3966..e19d3b5e7 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -1,6 +1,7 @@ import numpy as np from typing import Collection, Iterable, List, Sequence, Set from augur.errors import AugurError +from augur.filter.debug import add_debugging from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier @@ -278,6 +279,7 @@ def _calculate_fractional_sequences_per_group( return (lo + hi) / 2 +@add_debugging def apply_subsampling(args): """Apply subsampling to update the filter reason table.