Skip to content

Commit

Permalink
test: consolidate fixtures
Browse files Browse the repository at this point in the history
* Remove fixtures with roughly duplicate functionality.
* Rename some fixtures for consistency and accuracy.
* Add typing.
* Some `ruff` auto-formatting got into these changes.
  • Loading branch information
igboyes authored Apr 16, 2024
1 parent fb50df9 commit 98cba92
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 115 deletions.
58 changes: 25 additions & 33 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import pytest

from virtool_cli.ref.init import init_reference
from virtool_cli.ncbi.cache import NCBICache
from virtool_cli.ncbi.client import NCBIClient
from virtool_cli.ref.init import init_reference


@pytest.fixture()
Expand All @@ -13,53 +14,44 @@ def test_files_path():


@pytest.fixture()
def src_malformed_path(test_files_path: Path) -> Path:
return test_files_path / "src_malformed"


@pytest.fixture()
def src_test_path(test_files_path: Path) -> Path:
return test_files_path / "src_test"


@pytest.fixture()
def src_scratch_path(src_test_path: Path, tmp_path: Path) -> Path:
path = tmp_path / "src_scratch"
shutil.copytree(src_test_path, path)

yield path

shutil.rmtree(path)


@pytest.fixture()
def scratch_path(src_test_path: Path, cache_example_path: Path, tmp_path: Path) -> Path:
def scratch_path(
test_files_path: Path,
tmp_path: Path,
) -> Path:
"""The path to a scratch reference repository."""
path = tmp_path / "reference"

init_reference(path)

shutil.copytree(src_test_path, path / "src", dirs_exist_ok=True)
shutil.copytree(cache_example_path, path / ".cache", dirs_exist_ok=True)
shutil.copytree(test_files_path / "src_test", path / "src", dirs_exist_ok=True)
shutil.copytree(test_files_path / "cache_test", path / ".cache", dirs_exist_ok=True)

yield path

shutil.rmtree(path)


@pytest.fixture()
def cache_example_path(test_files_path: Path) -> Path:
return test_files_path / "cache_test"
def scratch_src_path(scratch_path: Path) -> Path:
"""The source path of a scratch reference repository."""
return scratch_path / "src"


@pytest.fixture()
def cache_scratch_path(cache_example_path, tmp_path: Path) -> Path:
path = tmp_path / "cache_scratch"
shutil.copytree(cache_example_path, path)
def scratch_ncbi_cache_path(
scratch_path: Path,
) -> Path:
"""The path to a scratch NCBI client cache."""
return scratch_path / ".cache"

yield path

shutil.rmtree(path)
@pytest.fixture()
def scratch_ncbi_cache(scratch_ncbi_cache_path: Path):
"""A scratch NCBI cache with preloaded data."""
return NCBICache(scratch_ncbi_cache_path)


@pytest.fixture()
def scratch_client(cache_scratch_path):
return NCBIClient(cache_scratch_path, ignore_cache=False)
def scratch_ncbi_client(scratch_ncbi_cache_path: Path):
"""A scratch NCBI client with a preloaded cache."""
return NCBIClient(scratch_ncbi_cache_path, ignore_cache=False)
11 changes: 7 additions & 4 deletions tests/test_group_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ class TestGroupRecords:
],
)
def test_group_records_by_isolate_success(
self, accessions, scratch_client, snapshot: SnapshotAssertion
self,
accessions,
scratch_ncbi_client,
snapshot: SnapshotAssertion,
):
records = scratch_client.fetch_genbank_records(accessions)
records = scratch_ncbi_client.fetch_genbank_records(accessions)

assert records

Expand All @@ -29,8 +32,8 @@ def test_group_records_by_isolate_success(
assert grouped_records[source_key] == snapshot

@pytest.mark.parametrize("accessions", [["Y11023"]])
def test_group_records_by_isolate_failure(self, accessions, scratch_client):
records = scratch_client.fetch_genbank_records(accessions)
def test_group_records_by_isolate_failure(self, accessions, scratch_ncbi_client):
records = scratch_ncbi_client.fetch_genbank_records(accessions)

assert records

Expand Down
51 changes: 33 additions & 18 deletions tests/test_ncbi_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from syrupy import SnapshotAssertion
from pathlib import Path

import pytest
from syrupy import SnapshotAssertion

from virtool_cli.ncbi.cache import NCBICache

Expand All @@ -12,12 +13,12 @@ def empty_cache_path(tmp_path):


def get_test_record(accession: str, cache_example_path) -> dict:
with open(cache_example_path / "genbank" / f"{accession}.json", "r") as f:
with open(cache_example_path / "genbank" / f"{accession}.json") as f:
return json.load(f)


def get_test_taxonomy(taxon_id: int, cache_example_path) -> dict:
with open(cache_example_path / "taxonomy" / f"{taxon_id}.json", "r") as f:
with open(cache_example_path / "taxonomy" / f"{taxon_id}.json") as f:
return json.load(f)


Expand All @@ -35,8 +36,8 @@ def test_cache_init(empty_cache_path):
assert cache._taxonomy_path.exists()


def test_cache_clear(cache_scratch_path):
cache = NCBICache(path=cache_scratch_path)
def test_cache_clear(scratch_ncbi_cache_path):
cache = NCBICache(path=scratch_ncbi_cache_path)

assert list(cache._genbank_path.glob("*.json")) != []
assert list(cache._taxonomy_path.glob("*.json")) != []
Expand All @@ -56,50 +57,64 @@ def test_cache_clear(cache_scratch_path):
)
class TestCacheGenbankOperations:
def test_cache_genbank_load_record_batch(
self, accessions, cache_scratch_path, snapshot: SnapshotAssertion
self,
accessions,
scratch_ncbi_cache_path,
snapshot: SnapshotAssertion,
):
scratch_cache = NCBICache(cache_scratch_path)
scratch_ncbi_cache = NCBICache(scratch_ncbi_cache_path)

for accession in accessions:
record = scratch_cache.load_genbank_record(accession)
record = scratch_ncbi_cache.load_genbank_record(accession)

assert record == snapshot(name=f"{accession}.json")

def test_cache_genbank_cache_records(
self, accessions, cache_example_path, empty_cache_path
self,
accessions,
scratch_ncbi_cache_path: Path,
empty_cache_path,
):
assert not empty_cache_path.exists()

cache = NCBICache(empty_cache_path)

for accession in accessions:
record = get_test_record(accession, cache_example_path)
record = get_test_record(accession, scratch_ncbi_cache_path)

cache.cache_genbank_record(data=record, accession=accession)

assert (cache._genbank_path / f"{accession}.json").exists()


@pytest.mark.parametrize("fake_accession", ["afjshd", "23222", "wheelhouse"])
def test_cache_genbank_load_fail(fake_accession, cache_scratch_path):
scratch_cache = NCBICache(cache_scratch_path)
def test_cache_genbank_load_fail(fake_accession, scratch_ncbi_cache_path):
scratch_ncbi_cache = NCBICache(scratch_ncbi_cache_path)

assert scratch_cache.load_genbank_record(fake_accession) is None
assert scratch_ncbi_cache.load_genbank_record(fake_accession) is None


@pytest.mark.parametrize("taxid", (270478, 438782, 1198450))
class TestCacheTaxonomyOperations:
def test_cache_taxonomy_load(
self, taxid, cache_scratch_path, snapshot: SnapshotAssertion
self,
taxid,
scratch_ncbi_cache_path,
snapshot: SnapshotAssertion,
):
scratch_cache = NCBICache(cache_scratch_path)
scratch_ncbi_cache = NCBICache(scratch_ncbi_cache_path)

taxonomy = scratch_cache.load_taxonomy(taxid)
taxonomy = scratch_ncbi_cache.load_taxonomy(taxid)

assert taxonomy == snapshot(name=f"{taxid}.json")

def test_cache_taxonomy_cache(self, taxid, cache_example_path, empty_cache_path):
taxonomy = get_test_taxonomy(taxid, cache_example_path)
def test_cache_taxonomy_cache(
self,
taxid: int,
scratch_ncbi_cache_path: Path,
empty_cache_path,
):
taxonomy = get_test_taxonomy(taxid, scratch_ncbi_cache_path)

fresh_cache = NCBICache(empty_cache_path)

Expand Down
Loading

0 comments on commit 98cba92

Please sign in to comment.