diff --git a/tests/snapshot/test_index.py b/tests/snapshot/test_index.py
index 69f16b13..150ebfac 100644
--- a/tests/snapshot/test_index.py
+++ b/tests/snapshot/test_index.py
@@ -1,11 +1,12 @@
 import pytest
 
-from virtool_cli.ref.snapshot.index import SnapshotIndex
+from virtool_cli.ref.repo import EventSourcedRepo
+from virtool_cli.ref.snapshot.index import Snapshotter
 
 
 @pytest.fixture()
 def snapshotter(scratch_repo):
-    return SnapshotIndex(scratch_repo.path / ".cache/snapshot")
+    return Snapshotter(scratch_repo.path / ".cache/snapshot")
 
 
 class TestSnapshotIndex:
@@ -19,27 +20,36 @@ def test_otu_ids(self, scratch_repo, snapshotter):
         for otu_id in true_otu_ids:
             assert otu_id in snapshotter.id_to_taxid
 
-    def test_taxids(self, scratch_repo, snapshotter):
-        true_otu_taxids = [
-            otu.taxid for otu in scratch_repo.get_all_otus(ignore_cache=True)
-        ]
+    def test_load_by_id(self, snapshotter: Snapshotter, scratch_repo: EventSourcedRepo):
+        """Test that we can load an OTU by its ID."""
+        otu_ids = [otu.id for otu in scratch_repo.get_all_otus(ignore_cache=True)]
 
-        assert snapshotter.taxids
+        for otu_id in otu_ids:
+            assert snapshotter.load_by_id(otu_id).id == otu_id
 
-        assert len(true_otu_taxids) == len(snapshotter.taxids)
+    def test_load_by_taxid(
+        self,
+        scratch_repo: EventSourcedRepo,
+        snapshotter: Snapshotter,
+    ):
+        """Test that we can load an OTU by its taxid."""
+        taxids = [otu.taxid for otu in scratch_repo.get_all_otus(ignore_cache=True)]
 
-        for taxid in true_otu_taxids:
-            assert taxid in snapshotter.index_by_taxid
+        for taxid in taxids:
+            assert snapshotter.load_by_taxid(taxid).taxid == taxid
 
-    def test_names(self, scratch_repo, snapshotter):
+    def test_load_by_name(
+        self,
+        scratch_repo: EventSourcedRepo,
+        snapshotter: Snapshotter,
+    ):
+        """Test that we can load an OTU by its name."""
         true_otu_names = [
             otu.name for otu in scratch_repo.get_all_otus(ignore_cache=True)
         ]
 
-        assert len(true_otu_names) == len(snapshotter.index_by_name)
-
         for name in true_otu_names:
-            assert name in snapshotter.index_by_name
+            assert snapshotter.load_by_name(name).name == name
 
     def test_accessions(self, scratch_repo, snapshotter):
         true_accessions = set()
@@ -69,7 +79,11 @@ class TestSnapshotIndexCaching:
         ],
     )
     def test_load_otu_by_taxid(
-        self, taxid: int, accessions: list[str], scratch_repo, snapshotter
+        self,
+        taxid: int,
+        accessions: list[str],
+        scratch_repo,
+        snapshotter,
     ):
         scratch_repo.snapshot()
 
@@ -77,7 +91,7 @@ def test_load_otu_by_taxid(
 
         assert rehydrated_otu
 
-        snapshot_otu = snapshotter.load_otu_by_taxid(taxid)
+        snapshot_otu = snapshotter.load_by_taxid(taxid)
 
         assert snapshot_otu
 
diff --git a/virtool_cli/ref/otu.py b/virtool_cli/ref/otu.py
index db2a34f7..3cc3a73d 100644
--- a/virtool_cli/ref/otu.py
+++ b/virtool_cli/ref/otu.py
@@ -14,19 +14,22 @@
 
 
 def create_otu(
-    repo: EventSourcedRepo, taxid: int, ignore_cache: bool = False
+    repo: EventSourcedRepo,
+    taxid: int,
+    ignore_cache: bool = False,
 ) -> EventSourcedRepoOTU:
     """Initialize a new OTU from a Taxonomy ID."""
     logger = base_logger.bind(taxid=taxid)
 
-    if taxid in repo.taxids:
+    if repo.get_otu_by_taxid(taxid):
         raise ValueError(
-            f"Taxonomy ID {taxid} has already been added to this reference."
+ f"Taxonomy ID {taxid} has already been added to this reference.", ) ncbi = NCBIClient.from_repo(repo.path, ignore_cache) taxonomy = ncbi.fetch_taxonomy_record(taxid) + if taxonomy is None: logger.fatal(f"Taxonomy ID {taxid} not found") sys.exit(1) @@ -49,10 +52,13 @@ def create_otu( def update_otu( - repo: EventSourcedRepo, otu: EventSourcedRepoOTU, ignore_cache: bool = False + repo: EventSourcedRepo, + otu: EventSourcedRepoOTU, + ignore_cache: bool = False, ): """Fetch a full list of Nucleotide accessions associated with the OTU - and pass the list to the add method.""" + and pass the list to the add method. + """ ncbi = NCBIClient.from_repo(repo.path, ignore_cache) linked_accessions = ncbi.link_accessions_from_taxid(otu.taxid) @@ -79,10 +85,8 @@ def group_genbank_records_by_isolate( for source_type in IsolateNameType: if source_type in record.source.model_fields_set: isolate_name = IsolateName( - **{ - "type": IsolateNameType(source_type), - "value": record.source.model_dump()[source_type], - }, + type=IsolateNameType(source_type), + value=record.source.model_dump()[source_type], ) isolates[isolate_name][record.accession] = record @@ -97,10 +101,8 @@ def group_genbank_records_by_isolate( ) isolate_name = IsolateName( - **{ - "type": IsolateNameType(IsolateNameType.REFSEQ), - "value": record.accession, - }, + type=IsolateNameType(IsolateNameType.REFSEQ), + value=record.accession, ) isolates[isolate_name][record.accession] = record @@ -119,7 +121,8 @@ def add_sequences( ignore_cache: bool = False, ): """Take a list of accessions, filter for eligible accessions and - add new sequences to the OTU""" + add new sequences to the OTU + """ client = NCBIClient.from_repo(repo.path, ignore_cache) otu_logger = base_logger.bind(taxid=otu.taxid, otu_id=str(otu.id), name=otu.name) @@ -147,7 +150,7 @@ def add_sequences( isolate_id = otu.get_isolate_id_by_name(isolate_key) if isolate_id is None: otu_logger.debug( - f"Creating isolate for {isolate_key.type}, {isolate_key.value}" + f"Creating isolate for {isolate_key.type}, {isolate_key.value}", ) isolate = repo.create_isolate( otu_id=otu.id, @@ -181,7 +184,7 @@ def add_sequences( ) else: - otu_logger.info(f"No new sequences added to OTU") + otu_logger.info("No new sequences added to OTU") def get_molecule_from_records(records: list[NCBIGenbank]) -> Molecule: @@ -189,17 +192,13 @@ def get_molecule_from_records(records: list[NCBIGenbank]) -> Molecule: for record in records: if record.refseq: return Molecule( - **{ - "strandedness": record.strandedness.value, - "type": record.moltype.value, - "topology": record.topology.value, - } + strandedness=record.strandedness.value, + type=record.moltype.value, + topology=record.topology.value, ) return Molecule( - **{ - "strandedness": records[0].strandedness.value, - "type": records[0].moltype.value, - "topology": records[0].topology.value, - } + strandedness=records[0].strandedness.value, + type=records[0].moltype.value, + topology=records[0].topology.value, ) diff --git a/virtool_cli/ref/repo.py b/virtool_cli/ref/repo.py index 3c808b26..a4a1d9bd 100644 --- a/virtool_cli/ref/repo.py +++ b/virtool_cli/ref/repo.py @@ -23,6 +23,7 @@ from orjson import orjson from structlog import get_logger +from virtool_cli.ref.event_index_cache import EventIndexCache, EventIndexCacheError from virtool_cli.ref.events import ( CreateIsolate, CreateIsolateData, @@ -48,10 +49,9 @@ EventSourcedRepoSequence, RepoMeta, ) -from virtool_cli.ref.snapshot.index import SnapshotIndex -from virtool_cli.ref.event_index_cache import 
EventIndexCache, EventIndexCacheError -from virtool_cli.utils.models import Molecule +from virtool_cli.ref.snapshot.index import Snapshotter from virtool_cli.ref.utils import DataType, IsolateName, pad_zeroes +from virtool_cli.utils.models import Molecule logger = get_logger("repo") @@ -78,10 +78,10 @@ def __init__(self, path: Path): self._event_index_cache = EventIndexCache(self.cache_path / "event_index") """The event index cache of the event sourced repository.""" - self._snapshotter = ( - SnapshotIndex(path=self._snapshot_path) + self._snapshotter: Snapshotter = ( + Snapshotter(path=self._snapshot_path) if self._snapshot_path.exists() - else SnapshotIndex.new(path=self._snapshot_path, metadata=self.meta) + else Snapshotter.new(path=self._snapshot_path, metadata=self.meta) ) """The snapshot index. Maintains and caches the read model of the Repo.""" @@ -149,19 +149,10 @@ def src_path(self) -> Path: """The path to the repo src directory.""" return self._event_store.path - @property - def taxids(self) -> set: - """Extant Taxonomy ids in the read model""" - return self._snapshotter.taxids - - @property - def accessions(self) -> set: - """Extant accessions in the read model""" - return self._snapshotter.accessions - def _get_event_index(self) -> dict[uuid.UUID, list[int]]: """Get the current event index from the event store, - binned and indexed by OTU Id.""" + binned and indexed by OTU Id. + """ otu_event_index = defaultdict(list) for event in self._event_store.iter_events(): @@ -171,7 +162,8 @@ def _get_event_index(self) -> dict[uuid.UUID, list[int]]: return otu_event_index def _get_event_index_after_start( - self, start: int = 1 + self, + start: int = 1, ) -> dict[uuid.UUID, list[int]]: """Get the current event index, binned and indexed by OTU ID""" otu_event_index = defaultdict(list) @@ -220,12 +212,14 @@ def create_otu( taxid: int, ): """Create an OTU.""" - if taxid in self._snapshotter.taxids: + if otu := self.get_otu_by_taxid(taxid): raise ValueError( - f"OTU already exists as {self._snapshotter.index_by_taxid[taxid]}", + f"OTU already exists as {otu}", ) + if name in self._snapshotter.index_by_name: raise ValueError(f"An OTU with the name '{name}' already exists") + if legacy_id in self._snapshotter.index_by_legacy_id: raise ValueError(f"An OTU with the legacy ID '{legacy_id}' already exists") @@ -265,13 +259,15 @@ def create_isolate( source_type: str, ) -> EventSourcedRepoIsolate | None: """Create and return a new isolate within the given OTU. - If the isolate name already exists, return None.""" + If the isolate name already exists, return None. + """ otu = self.get_otu(otu_id, ignore_cache=False) - name = IsolateName(**{"type": source_type, "value": source_name}) + name = IsolateName(type=source_type, value=source_name) if otu.get_isolate_id_by_name(name) is not None: logger.warning( - "An isolate by this name already exists", isolate_name=str(name) + "An isolate by this name already exists", + isolate_name=str(name), ) return None @@ -314,7 +310,8 @@ def create_sequence( sequence: str, ) -> EventSourcedRepoSequence | None: """Create and return a new sequence within the given OTU. - If the accession already exists in this OTU, return None.""" + If the accession already exists in this OTU, return None. 
+ """ otu = self.get_otu(otu_id, ignore_cache=False) if accession in otu.accessions: @@ -385,16 +382,18 @@ def read_otu(self, otu_id: uuid.UUID) -> EventSourcedRepoOTU | None: """Return an OTU corresponding to a UUID if found in snapshot, else None""" logger.debug("Loading OTU from snapshot...", otu_id=str(otu_id)) - return self._snapshotter.load_otu(otu_id) + return self._snapshotter.load_by_id(otu_id) def read_otu_by_taxid(self, taxid: int) -> EventSourcedRepoOTU | None: """Return an OTU corresponding to a Taxonomy ID if found in snapshot, else None""" logger.debug("Loading OTU from snapshot...", taxid=taxid) - return self._snapshotter.load_otu_by_taxid(taxid) + return self._snapshotter.load_by_taxid(taxid) def get_otu( - self, otu_id: uuid.UUID, ignore_cache: bool = False + self, + otu_id: uuid.UUID, + ignore_cache: bool = False, ) -> EventSourcedRepoOTU | None: """Return an OTU corresponding with a given OTU Id if it exists, else None.""" logger.debug("Getting OTU from events...", otu_id=str(otu_id)) @@ -405,8 +404,16 @@ def get_otu( return None + def get_otu_by_name( + self, + name: str, + ): + return self._snapshotter.load_by_name(name) + def get_otu_by_taxid( - self, taxid: int, ignore_cache: bool = False + self, + taxid: int, + ignore_cache: bool = False, ) -> EventSourcedRepoOTU | None: """Return an OTU corresponding with a given OTU Id if it exists, else None""" if (otu_id := self._snapshotter.index_by_taxid.get(taxid)) is not None: @@ -493,10 +500,11 @@ def _get_otu_metadata(self, event_ids: list[int]) -> dict | None: } def _get_otu_events( - self, otu_id: uuid.UUID, ignore_cache: bool = False + self, + otu_id: uuid.UUID, + ignore_cache: bool = False, ) -> list[int]: - """ - Returns an up-to-date list of events associated with this OTU Id. + """Returns an up-to-date list of events associated with this OTU Id. If ignore_cache, loads the OTU's event index cache and makes sure the results are up to date before returning the list. 
@@ -540,7 +548,9 @@ def _get_otu_events(
             otu_logger.debug("Writing events to cache...", events=event_ids)
 
             self._event_index_cache.cache_otu_events(
-                otu_id, event_ids, last_id=self.last_id
+                otu_id,
+                event_ids,
+                last_id=self.last_id,
             )
 
         return event_ids
@@ -569,7 +579,7 @@ def _load_otu_events_from_cache_and_update(self, otu_id: uuid.UUID) -> list[int]
         if cached_otu_index.at_event > self.last_id:
             raise EventIndexCacheError(
                 "Bad Index: "
-                + "Cached event index is greater than current repo's last ID"
+                + "Cached event index is greater than current repo's last ID",
             )
 
         # Update event list
@@ -583,7 +593,7 @@ def _load_otu_events_from_cache_and_update(self, otu_id: uuid.UUID) -> list[int]
         otu_event_list = cached_otu_index.events
 
         for event in self._event_store.iter_events_from_index(
-            start=cached_otu_index.at_event
+            start=cached_otu_index.at_event,
         ):
             if type(event) in OTU_EVENT_TYPES and event.id not in otu_event_list:
                 otu_event_list.append(event.id)
@@ -597,7 +607,9 @@ def _load_otu_events_from_cache_and_update(self, otu_id: uuid.UUID) -> list[int]
             )
 
             self._event_index_cache.cache_otu_events(
-                otu_id, otu_event_list, last_id=self.last_id
+                otu_id,
+                otu_event_list,
+                last_id=self.last_id,
             )
 
         return otu_event_list
@@ -660,7 +672,7 @@ def iter_events_from_index(self, start: int = 1) -> Generator[Event, None, None]
 
     def read_event(self, event_id: int) -> Event:
         return EventStore._read_event_at_path(
-            self.path / f"{pad_zeroes(event_id)}.json"
+            self.path / f"{pad_zeroes(event_id)}.json",
         )
 
     def write_event(
diff --git a/virtool_cli/ref/snapshot/index.py b/virtool_cli/ref/snapshot/index.py
index 889142b5..976d502b 100644
--- a/virtool_cli/ref/snapshot/index.py
+++ b/virtool_cli/ref/snapshot/index.py
@@ -1,16 +1,14 @@
+from collections.abc import Generator
 from dataclasses import dataclass
 from pathlib import Path
-import shutil
 from uuid import UUID
-from collections.abc import Generator
 
 import orjson
 from structlog import get_logger
 
-from virtool_cli.ref.resources import RepoMeta, EventSourcedRepoOTU
+from virtool_cli.ref.resources import EventSourcedRepoOTU, RepoMeta
 from virtool_cli.ref.snapshot.otu import OTUSnapshot
 
-
 logger = get_logger()
 
 
@@ -55,9 +53,8 @@ def __repr__(self):
         )
 
 
-class SnapshotIndex:
-    """Manages OTUSnapshot loading and caching,
-    and maintains an index of the contents."""
+class Snapshotter:
+    """Load and cache OTU snapshots."""
 
     def __init__(self, path: Path):
         self.path = path
@@ -82,7 +79,7 @@ def new(cls, path: Path, metadata: RepoMeta):
         with open(path / "meta.json", "wb") as f:
             f.write(orjson.dumps(metadata.model_dump()))
 
-        return SnapshotIndex(path)
+        return Snapshotter(path)
 
     @property
     def id_to_taxid(self) -> dict[UUID, int]:
@@ -124,22 +121,10 @@ def otu_ids(self) -> set[UUID]:
 
         return set(self._index.keys())
 
-    @property
-    def taxids(self) -> set[int]:
-        """A list of Taxonomy IDs of snapshots."""
-        self._update_index()
-
-        return set(self.index_by_taxid.keys())
-
     @property
     def accessions(self) -> set[str]:
         return set(self._get_accession_index().keys())
 
-    def clean(self):
-        """Remove and remake snapshot cache directory"""
-        shutil.rmtree(self.path)
-        self.path.mkdir(exist_ok=True)
-
     def snapshot(
         self,
         otus: list[EventSourcedRepoOTU],
@@ -167,10 +152,13 @@ def snapshot(
     def iter_otus(self) -> Generator[EventSourcedRepoOTU, None, None]:
         """Iterate over the OTUs in the snapshot"""
         for otu_id in self.otu_ids:
-            yield self.load_otu(otu_id)
+            yield self.load_by_id(otu_id)
 
     def cache_otu(
-        self, otu: "EventSourcedRepoOTU", at_event: int | None = None, options=None
+        self,
+        otu: "EventSourcedRepoOTU",
+        at_event: int | None = None,
+        options=None,
     ):
         """Snapshots a single OTU"""
         logger.debug(f"Writing a snapshot for {otu.taxid}...")
@@ -180,7 +168,7 @@ def cache_otu(
         self._index[otu.id] = OTUKeys.from_otu(otu)
         self._cache_index()
 
-    def load_otu(self, otu_id: UUID) -> EventSourcedRepoOTU | None:
+    def load_by_id(self, otu_id: UUID) -> EventSourcedRepoOTU | None:
         """Loads an OTU from the most recent repo snapshot"""
         try:
             otu_snap = OTUSnapshot(self.path / f"{otu_id}")
@@ -189,11 +177,23 @@ def load_by_id(self, otu_id: UUID) -> EventSourcedRepoOTU | None:
 
         return otu_snap.load()
 
-    def load_otu_by_taxid(self, taxid: int) -> EventSourcedRepoOTU | None:
+    def load_by_name(self, name: str) -> EventSourcedRepoOTU | None:
+        """Takes an OTU name and returns an OTU from the most recent snapshot."""
+        otu_id = self.index_by_name.get(name)
+
+        if otu_id:
+            return self.load_by_id(otu_id)
+
+        return None
+
+    def load_by_taxid(self, taxid: int) -> EventSourcedRepoOTU | None:
         """Takes a Taxonomy ID and returns an OTU from the most recent snapshot."""
         otu_id = self.index_by_taxid[taxid]
 
-        return self.load_otu(otu_id)
+        if otu_id:
+            return self.load_by_id(otu_id)
+
+        return None
 
     def _build_index(self) -> dict[UUID, OTUKeys]:
         """Build a new index from the contents of the snapshot cache directory"""
@@ -205,7 +205,7 @@ def _build_index(self) -> dict[UUID, OTUKeys]:
             except ValueError:
                 continue
 
-            otu = self.load_otu(otu_id)
+            otu = self.load_by_id(otu_id)
             if otu is None:
                 raise FileNotFoundError("OTU not found")
             index[otu.id] = OTUKeys(
@@ -253,7 +253,6 @@ def _load_index(self) -> dict | None:
 
     def _update_index(self):
         """Update the index in memory."""
-
        filename_index = {str(otu_id) for otu_id in self._index}
 
         for subpath in self.path.iterdir():
@@ -265,7 +264,7 @@
             except ValueError:
                 continue
 
-            unindexed_otu = self.load_otu(unlisted_otu_id)
+            unindexed_otu = self.load_by_id(unlisted_otu_id)
 
             self._index[unindexed_otu.id] = OTUKeys.from_otu(unindexed_otu)