diff --git a/augur/io/metadata.py b/augur/io/metadata.py index 8f9d4a711..e003cde21 100644 --- a/augur/io/metadata.py +++ b/augur/io/metadata.py @@ -1,5 +1,6 @@ import csv import pandas as pd +import pyfastx from augur.errors import AugurError from augur.io.print import print_err @@ -182,3 +183,64 @@ def read_table_to_dict(table, duplicate_reporting=DataErrorMethod.ERROR_FIRST, i raise AugurError(duplicates_message) else: raise ValueError(f"Encountered unhandled duplicate reporting method: {duplicate_reporting!r}") + + +def read_metadata_with_sequences(metadata, fasta, seq_id_column, seq_field='sequence'): + """ + Read rows from *metadata* file and yield each row as a single dict that has + been updated with their corresponding sequence from the *fasta* file. + Matches the metadata record with sequences using the sequence id provided + in the *seq_id_column*. To ensure that the sequences can be matched with + the metadata, the FASTA headers must contain the matching sequence id. The + FASTA headers may include additional description parts after the id, but + they will not be used to match the metadata. + + Note that metadata and sequence records that do not have a matching record + are skipped. + + Reads the *fasta* file with `pyfastx.Fasta`, which creates an index for + the file to allow random access of sequences via the sequence id. + See pyfastx docs for more details: + https://pyfastx.readthedocs.io/en/latest/usage.html#fasta + + Parameters + ---------- + metadata: str + Path to a CSV or TSV metadata file + + fasta: str + Path to a plain or gzipped FASTA file + + seq_id_column: str + The column in the metadata file that contains the sequence id for + matching sequences + + seq_field: str, optional + The field name to use for the sequence in the updated record + + Yields + ------ + dict + The parsed metadata record with the sequence + """ + sequences = pyfastx.Fasta(fasta) + sequence_ids = set(sequences.keys()) + + # Silencing duplicate reporting here because we will need to handle duplicates + # in both the metadata and FASTA files after processing all the records here. + for record in read_table_to_dict(metadata, duplicate_reporting=DataErrorMethod.SILENT): + seq_id = record.get(seq_id_column) + + if seq_id is None: + raise AugurError(f"The provided sequence id column {seq_id_column!r} does not exist in the metadata.") + + # Skip records that do not have a matching sequence + # TODO: change this to try/except to fetch sequences and catch + # KeyError for non-existing sequences when https://github.com/lmdu/pyfastx/issues/50 is resolved + if seq_id not in sequence_ids: + continue + + sequence_record = sequences[seq_id] + record[seq_field] = str(sequence_record.seq).upper() + + yield record diff --git a/setup.py b/setup.py index b0c0a0158..9c33a96a1 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,8 @@ "packaging >=19.2", "pandas >=1.0.0, ==1.*", "phylo-treetime >=0.9.3, ==0.9.*", - "xopen >=1.0.1, ==1.*" + "xopen >=1.0.1, ==1.*", + "pyfastx >=0.8.4, ==0.8.*" ], extras_require = { 'dev': [ diff --git a/tests/io/test_metadata.py b/tests/io/test_metadata.py index 4fd749abd..5a1fad409 100644 --- a/tests/io/test_metadata.py +++ b/tests/io/test_metadata.py @@ -1,7 +1,8 @@ import pytest +import shutil from augur.errors import AugurError -from augur.io.metadata import read_table_to_dict +from augur.io.metadata import read_table_to_dict, read_metadata_with_sequences from augur.types import DataErrorMethod @@ -85,3 +86,81 @@ def test_read_table_to_dict_with_duplicate_and_bad_id(self, metadata_with_duplic with pytest.raises(AugurError) as e_info: list(read_table_to_dict(metadata_with_duplicate, id_column=id_column)) assert str(e_info.value) == f"The provided id column {id_column!r} does not exist in {metadata_with_duplicate!r}." + + +@pytest.fixture +def fasta_file(tmpdir): + path = str(tmpdir / 'sequences.fasta') + with open(path, 'w') as fh: + fh.writelines([ + '>SEQ_A\nAAAA\n', + '>SEQ_T\nTTTT\n', + '>SEQ_C\nCCCC\n', + '>SEQ_G\nGGGG\n' + ]) + return path + +@pytest.fixture +def metadata_file(tmpdir): + path = str(tmpdir / 'metadata.tsv') + with open(path, 'w') as fh: + fh.writelines([ + 'strain\tcountry\tdate\n', + 'SEQ_A\tUSA\t2020-10-01\n', + 'SEQ_T\tUSA\t2020-10-02\n', + 'SEQ_C\tUSA\t2020-10-03\n', + 'SEQ_G\tUSA\t2020-10-04\n' + ]) + return path + +@pytest.fixture +def fasta_file_with_unmatched(tmpdir, fasta_file): + path = str(tmpdir / 'extra-sequences.fasta') + shutil.copy(fasta_file, path) + with open(path, 'a') as fh: + fh.writelines([ + ">SEQ_EXTRA_A\nAAAA\n", + ">SEQ_EXTRA_T\nTTTT\n" + ]) + return path + +@pytest.fixture +def metadata_file_with_unmatched(tmpdir, metadata_file): + path = str(tmpdir / 'extra-metadata.tsv') + shutil.copy(metadata_file, path) + with open(path, 'a') as fh: + fh.writelines([ + "SEQ_EXTRA_1\tUSA\t2020-10-01\n", + "SEQ_EXTRA_2\tUSA\t2020-10-02\n" + ]) + return path + +class TestReadMetadataWithSequence: + def test_read_metadata_with_sequence(self, metadata_file, fasta_file): + records = list(read_metadata_with_sequences(metadata_file, fasta_file, 'strain')) + assert len(records) == 4 + for record in records: + seq_base = record['strain'].split("_")[-1].upper() + expected_sequence = seq_base * 4 + assert record['sequence'] == expected_sequence + + def test_read_metadata_with_sequences_with_bad_id(self, metadata_file, fasta_file): + id_field = "bad_id" + with pytest.raises(AugurError) as e_info: + next(read_metadata_with_sequences(metadata_file, fasta_file, id_field)) + assert str(e_info.value) == f"The provided sequence id column {id_field!r} does not exist in the metadata." + + def test_read_metadata_with_sequences_with_unmatched_sequences(self, metadata_file, fasta_file_with_unmatched): + records = list(read_metadata_with_sequences(metadata_file, fasta_file_with_unmatched, 'strain')) + assert len(records) == 4 + assert [record['strain'] for record in records] == ['SEQ_A', 'SEQ_T', 'SEQ_C', 'SEQ_G'] + + def test_read_metadata_with_sequences_with_unmatched_metadata(self, metadata_file_with_unmatched, fasta_file): + records = list(read_metadata_with_sequences(metadata_file_with_unmatched, fasta_file, 'strain')) + assert len(records) == 4 + assert [record['strain'] for record in records] == ['SEQ_A', 'SEQ_T', 'SEQ_C', 'SEQ_G'] + + def test_read_metadata_with_sequences_with_unmatched_in_both(self, metadata_file_with_unmatched, fasta_file_with_unmatched): + records = list(read_metadata_with_sequences(metadata_file_with_unmatched, fasta_file_with_unmatched, 'strain')) + assert len(records) == 4 + assert [record['strain'] for record in records] == ['SEQ_A', 'SEQ_T', 'SEQ_C', 'SEQ_G']