Skip to content

Commit

Permalink
refactor: clean up snapshotter
Browse files Browse the repository at this point in the history
* Rename `SnapshotIndex` to `Snapshotter`.
* Remove unused repo properties and `Snapshotter.clean` method.
* Rename snapshotter load methods.
  • Loading branch information
igboyes authored Jun 20, 2024
1 parent 7627e4c commit 9397862
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 106 deletions.
46 changes: 30 additions & 16 deletions tests/snapshot/test_index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest

from virtool_cli.ref.snapshot.index import SnapshotIndex
from virtool_cli.ref.repo import EventSourcedRepo
from virtool_cli.ref.snapshot.index import Snapshotter


@pytest.fixture()
def snapshotter(scratch_repo):
return SnapshotIndex(scratch_repo.path / ".cache/snapshot")
return Snapshotter(scratch_repo.path / ".cache/snapshot")


class TestSnapshotIndex:
Expand All @@ -19,27 +20,36 @@ def test_otu_ids(self, scratch_repo, snapshotter):
for otu_id in true_otu_ids:
assert otu_id in snapshotter.id_to_taxid

def test_taxids(self, scratch_repo, snapshotter):
true_otu_taxids = [
otu.taxid for otu in scratch_repo.get_all_otus(ignore_cache=True)
]
def test_load_by_id(self, snapshotter: Snapshotter, scratch_repo: EventSourcedRepo):
"""Test that we can load an OTU by its ID."""
otu_ids = [otu.id for otu in scratch_repo.get_all_otus(ignore_cache=True)]

assert snapshotter.taxids
for otu_id in otu_ids:
assert snapshotter.load_by_id(otu_id).id == otu_id

assert len(true_otu_taxids) == len(snapshotter.taxids)
def test_load_by_taxid(
self,
scratch_repo: EventSourcedRepo,
snapshotter: Snapshotter,
):
"""Test that we can load an OTU by its taxid."""
taxids = [otu.taxid for otu in scratch_repo.get_all_otus(ignore_cache=True)]

for taxid in true_otu_taxids:
assert taxid in snapshotter.index_by_taxid
for taxid in taxids:
assert snapshotter.load_by_taxid(taxid).taxid == taxid

def test_names(self, scratch_repo, snapshotter):
def test_load_by_name(
self,
scratch_repo: EventSourcedRepo,
snapshotter: Snapshotter,
):
"""Test that we can load an OTU by its name."""
true_otu_names = [
otu.name for otu in scratch_repo.get_all_otus(ignore_cache=True)
]

assert len(true_otu_names) == len(snapshotter.index_by_name)

for name in true_otu_names:
assert name in snapshotter.index_by_name
assert snapshotter.load_by_name(name).name == name

def test_accessions(self, scratch_repo, snapshotter):
true_accessions = set()
Expand Down Expand Up @@ -69,15 +79,19 @@ class TestSnapshotIndexCaching:
],
)
def test_load_otu_by_taxid(
self, taxid: int, accessions: list[str], scratch_repo, snapshotter
self,
taxid: int,
accessions: list[str],
scratch_repo,
snapshotter,
):
scratch_repo.snapshot()

rehydrated_otu = scratch_repo.get_otu_by_taxid(taxid)

assert rehydrated_otu

snapshot_otu = snapshotter.load_otu_by_taxid(taxid)
snapshot_otu = snapshotter.load_by_taxid(taxid)

assert snapshot_otu

Expand Down
51 changes: 25 additions & 26 deletions virtool_cli/ref/otu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@


def create_otu(
repo: EventSourcedRepo, taxid: int, ignore_cache: bool = False
repo: EventSourcedRepo,
taxid: int,
ignore_cache: bool = False,
) -> EventSourcedRepoOTU:
"""Initialize a new OTU from a Taxonomy ID."""
logger = base_logger.bind(taxid=taxid)

if taxid in repo.taxids:
if repo.get_otu_by_taxid(taxid):
raise ValueError(
f"Taxonomy ID {taxid} has already been added to this reference."
f"Taxonomy ID {taxid} has already been added to this reference.",
)

ncbi = NCBIClient.from_repo(repo.path, ignore_cache)

taxonomy = ncbi.fetch_taxonomy_record(taxid)

if taxonomy is None:
logger.fatal(f"Taxonomy ID {taxid} not found")
sys.exit(1)
Expand All @@ -49,10 +52,13 @@ def create_otu(


def update_otu(
repo: EventSourcedRepo, otu: EventSourcedRepoOTU, ignore_cache: bool = False
repo: EventSourcedRepo,
otu: EventSourcedRepoOTU,
ignore_cache: bool = False,
):
"""Fetch a full list of Nucleotide accessions associated with the OTU
and pass the list to the add method."""
and pass the list to the add method.
"""
ncbi = NCBIClient.from_repo(repo.path, ignore_cache)

linked_accessions = ncbi.link_accessions_from_taxid(otu.taxid)
Expand All @@ -79,10 +85,8 @@ def group_genbank_records_by_isolate(
for source_type in IsolateNameType:
if source_type in record.source.model_fields_set:
isolate_name = IsolateName(
**{
"type": IsolateNameType(source_type),
"value": record.source.model_dump()[source_type],
},
type=IsolateNameType(source_type),
value=record.source.model_dump()[source_type],
)

isolates[isolate_name][record.accession] = record
Expand All @@ -97,10 +101,8 @@ def group_genbank_records_by_isolate(
)

isolate_name = IsolateName(
**{
"type": IsolateNameType(IsolateNameType.REFSEQ),
"value": record.accession,
},
type=IsolateNameType(IsolateNameType.REFSEQ),
value=record.accession,
)

isolates[isolate_name][record.accession] = record
Expand All @@ -119,7 +121,8 @@ def add_sequences(
ignore_cache: bool = False,
):
"""Take a list of accessions, filter for eligible accessions and
add new sequences to the OTU"""
add new sequences to the OTU
"""
client = NCBIClient.from_repo(repo.path, ignore_cache)

otu_logger = base_logger.bind(taxid=otu.taxid, otu_id=str(otu.id), name=otu.name)
Expand Down Expand Up @@ -147,7 +150,7 @@ def add_sequences(
isolate_id = otu.get_isolate_id_by_name(isolate_key)
if isolate_id is None:
otu_logger.debug(
f"Creating isolate for {isolate_key.type}, {isolate_key.value}"
f"Creating isolate for {isolate_key.type}, {isolate_key.value}",
)
isolate = repo.create_isolate(
otu_id=otu.id,
Expand Down Expand Up @@ -181,25 +184,21 @@ def add_sequences(
)

else:
otu_logger.info(f"No new sequences added to OTU")
otu_logger.info("No new sequences added to OTU")


def get_molecule_from_records(records: list[NCBIGenbank]) -> Molecule:
"""Return relevant molecule metadata from one or more records"""
for record in records:
if record.refseq:
return Molecule(
**{
"strandedness": record.strandedness.value,
"type": record.moltype.value,
"topology": record.topology.value,
}
strandedness=record.strandedness.value,
type=record.moltype.value,
topology=record.topology.value,
)

return Molecule(
**{
"strandedness": records[0].strandedness.value,
"type": records[0].moltype.value,
"topology": records[0].topology.value,
}
strandedness=records[0].strandedness.value,
type=records[0].moltype.value,
topology=records[0].topology.value,
)
Loading

0 comments on commit 9397862

Please sign in to comment.