Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue a warning if the number of alt alleles exceeds the maximum specified #620

Merged
merged 3 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sgkit/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ def zarrs_to_dataset(
ds[variable_name] = ds[variable_name].astype(f"S{max_length}")
del ds.attrs[attr]

if "max_alt_alleles_seen" in datasets[0].attrs:
ds.attrs["max_alt_alleles_seen"] = max(
ds.attrs["max_alt_alleles_seen"] for ds in datasets
)

return ds


Expand Down
3 changes: 2 additions & 1 deletion sgkit/io/vcf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
try:
from ..utils import zarrs_to_dataset
from .vcf_partition import partition_into_regions
from .vcf_reader import vcf_to_zarr, vcf_to_zarrs
from .vcf_reader import MaxAltAllelesExceededWarning, vcf_to_zarr, vcf_to_zarrs

__all__ = [
"MaxAltAllelesExceededWarning",
"partition_into_regions",
"vcf_to_zarr",
"vcf_to_zarrs",
Expand Down
35 changes: 32 additions & 3 deletions sgkit/io/vcf/vcf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from numcodecs import PackBits

from sgkit import variables
from sgkit.io.dataset import load_dataset
from sgkit.io.utils import zarrs_to_dataset
from sgkit.io.vcf import partition_into_regions
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
Expand All @@ -50,6 +51,12 @@
DEFAULT_COMPRESSOR = None


class MaxAltAllelesExceededWarning(UserWarning):
"""Warning when the number of alt alleles exceeds the maximum specified."""

pass


@contextmanager
def open_vcf(path: PathType) -> Iterator[VCF]:
"""A context manager for opening a VCF file."""
Expand Down Expand Up @@ -154,7 +161,7 @@ def for_field(
) -> "VcfFieldHandler":
if field == "FORMAT/GT":
return GenotypeFieldHandler(
vcf, chunk_length, ploidy, mixed_ploidy, truncate_calls
vcf, chunk_length, ploidy, mixed_ploidy, truncate_calls, max_alt_alleles
)
category = field.split("/")[0]
vcf_field_defs = _get_vcf_field_defs(vcf, category)
Expand Down Expand Up @@ -279,13 +286,15 @@ def __init__(
ploidy: int,
mixed_ploidy: bool,
truncate_calls: bool,
max_alt_alleles: int,
) -> None:
n_sample = len(vcf.samples)
self.call_genotype = np.empty((chunk_length, n_sample, ploidy), dtype="i1")
self.call_genotype_phased = np.empty((chunk_length, n_sample), dtype=bool)
self.ploidy = ploidy
self.mixed_ploidy = mixed_ploidy
self.truncate_calls = truncate_calls
self.max_alt_alleles = max_alt_alleles

def add_variant(self, i: int, variant: Any) -> None:
fill = -2 if self.mixed_ploidy else -1
Expand All @@ -298,6 +307,10 @@ def add_variant(self, i: int, variant: Any) -> None:
self.call_genotype[i, ..., 0:n] = gt[..., 0:n]
self.call_genotype[i, ..., n:] = fill
self.call_genotype_phased[i] = gt[..., -1]

# set any calls that exceed maximum number of alt alleles as missing
self.call_genotype[i][self.call_genotype[i] > self.max_alt_alleles] = -1

else:
self.call_genotype[i] = fill
self.call_genotype_phased[i] = 0
Expand Down Expand Up @@ -362,6 +375,7 @@ def vcf_to_zarr_sequential(
# Remember max lengths of variable-length strings
max_variant_id_length = 0
max_variant_allele_length = 0
max_alt_alleles_seen = 0

# Iterate through variants in batches of chunk_length

Expand Down Expand Up @@ -413,6 +427,7 @@ def vcf_to_zarr_sequential(
variant_position[i] = variant.POS

alleles = [variant.REF] + variant.ALT
max_alt_alleles_seen = max(max_alt_alleles_seen, len(variant.ALT))
if len(alleles) > n_allele:
alleles = alleles[:n_allele]
elif len(alleles) < n_allele:
Expand Down Expand Up @@ -457,6 +472,7 @@ def vcf_to_zarr_sequential(
if add_str_max_length_attrs:
ds.attrs["max_length_variant_id"] = max_variant_id_length
ds.attrs["max_length_variant_allele"] = max_variant_allele_length
ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen

if first_variants_chunk:
# Enforce uniform chunks in the variants dimension
Expand Down Expand Up @@ -605,7 +621,9 @@ def vcf_to_zarrs(
specified ploidy will raise an exception.
max_alt_alleles
The (maximum) number of alternate alleles in the VCF file. Any records with more than
this number of alternate alleles will have the extra alleles dropped.
this number of alternate alleles will have the extra alleles dropped (the `variant_allele`
variable will be truncated). Any call genotype fields with the extra alleles will
be changed to the missing-allele sentinel value of -1.
fields
Extra fields to extract data for. A list of strings, with ``INFO`` or ``FORMAT`` prefixes.
Wildcards are permitted too, for example: ``["INFO/*", "FORMAT/DP"]``.
Expand Down Expand Up @@ -772,7 +790,9 @@ def vcf_to_zarr(
specified ploidy will raise an exception.
max_alt_alleles
The (maximum) number of alternate alleles in the VCF file. Any records with more than
this number of alternate alleles will have the extra alleles dropped.
this number of alternate alleles will have the extra alleles dropped (the `variant_allele`
variable will be truncated). Any call genotype fields with the extra alleles will
be changed to the missing-allele sentinel value of -1.
fields
Extra fields to extract data for. A list of strings, with ``INFO`` or ``FORMAT`` prefixes.
Wildcards are permitted too, for example: ``["INFO/*", "FORMAT/DP"]``.
Expand Down Expand Up @@ -839,6 +859,15 @@ def vcf_to_zarr(
field_defs=field_defs,
)

# Issue a warning if max_alt_alleles caused data to be dropped
ds = load_dataset(output)
max_alt_alleles_seen = ds.attrs["max_alt_alleles_seen"]
if max_alt_alleles_seen > max_alt_alleles:
jeromekelleher marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
f"Some alternate alleles were dropped, since actual max value {max_alt_alleles_seen} exceeded max_alt_alleles setting of {max_alt_alleles}.",
MaxAltAllelesExceededWarning,
)


def count_variants(path: PathType, region: Optional[str] = None) -> int:
"""Count the number of variants in a VCF file."""
Expand Down
86 changes: 68 additions & 18 deletions sgkit/tests/io/vcf/test_vcf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from numpy.testing import assert_allclose, assert_array_equal

from sgkit import load_dataset
from sgkit.io.vcf import partition_into_regions, vcf_to_zarr
from sgkit.io.vcf import (
MaxAltAllelesExceededWarning,
partition_into_regions,
vcf_to_zarr,
)

from .utils import path_for_test

Expand Down Expand Up @@ -98,30 +102,41 @@ def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "sample.vcf.gz", is_path)
output = tmp_path.joinpath("vcf.zarr").as_posix()

vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1)
ds = xr.open_zarr(output)
with pytest.warns(MaxAltAllelesExceededWarning):
max_alt_alleles = 1
vcf_to_zarr(
path, output, chunk_length=5, chunk_width=2, max_alt_alleles=max_alt_alleles
)
ds = xr.open_zarr(output)

# extra alt alleles are silently dropped
assert_array_equal(
ds["variant_allele"],
[
["A", "C"],
["A", "G"],
["G", "A"],
["T", "A"],
["A", "G"],
["T", ""],
["G", "GA"],
["T", ""],
["AC", "A"],
],
)
# extra alt alleles are dropped
assert_array_equal(
ds["variant_allele"],
[
["A", "C"],
["A", "G"],
["G", "A"],
["T", "A"],
["A", "G"],
["T", ""],
["G", "GA"],
["T", ""],
["AC", "A"],
],
)

# genotype calls are truncated
assert np.all(ds["call_genotype"].values <= max_alt_alleles)

# the maximum number of alt alleles actually seen is stored as an attribute
assert ds.attrs["max_alt_alleles_seen"] == 3


@pytest.mark.parametrize(
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -159,6 +174,7 @@ def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output: MutableMapping[str, bytes] = {}
Expand Down Expand Up @@ -217,6 +233,7 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
Expand Down Expand Up @@ -266,6 +283,7 @@ def test_vcf_to_zarr__empty_region(shared_datadir, is_path, tmp_path):
"is_path",
[False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
Expand Down Expand Up @@ -354,6 +372,7 @@ def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
Expand Down Expand Up @@ -381,6 +400,7 @@ def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
Expand Down Expand Up @@ -410,6 +430,7 @@ def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
"is_path",
[True, False],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
Expand Down Expand Up @@ -456,6 +477,31 @@ def test_vcf_to_zarr__mutiple_partitioned_invalid_regions(
vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000)


@pytest.mark.parametrize(
"is_path",
[True, False],
)
def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path):
paths = [
path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
]
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()

with pytest.warns(MaxAltAllelesExceededWarning):
vcf_to_zarr(
paths,
output,
target_part_size="40KB",
chunk_length=5_000,
max_alt_alleles=1,
)
ds = xr.open_zarr(output)

# the maximum number of alt alleles actually seen is stored as an attribute
assert ds.attrs["max_alt_alleles_seen"] == 7


@pytest.mark.parametrize(
"ploidy,mixed_ploidy,truncate_calls,regions",
[
Expand Down Expand Up @@ -647,6 +693,7 @@ def test_vcf_to_zarr__fields(shared_datadir, tmp_path):
assert ds["call_DP"].attrs["comment"] == "Read Depth"


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -703,6 +750,7 @@ def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path):
assert "comment" not in ds["variant_DP"].attrs


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "sample.vcf.gz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -736,6 +784,7 @@ def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path):
)


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down Expand Up @@ -768,6 +817,7 @@ def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path):
)


@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_vcf_to_zarr__field_number_G(shared_datadir, tmp_path):
path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz")
output = tmp_path.joinpath("vcf.zarr").as_posix()
Expand Down
22 changes: 17 additions & 5 deletions sgkit/tests/io/vcf/test_vcf_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_default_fields(shared_datadir, tmpdir):
sg_vcfzarr_path = create_sg_vcfzarr(shared_datadir, tmpdir)
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel

assert_identical(allel_ds, sg_ds)

Expand Down Expand Up @@ -107,21 +108,29 @@ def test_DP_field(shared_datadir, tmpdir):
)
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel

assert_identical(allel_ds, sg_ds)


@pytest.mark.parametrize(
"vcf_file,allel_exclude_fields,sgkit_exclude_fields",
"vcf_file,allel_exclude_fields,sgkit_exclude_fields,max_alt_alleles",
[
("sample.vcf.gz", None, None),
("mixed.vcf.gz", None, None),
("sample.vcf.gz", None, None, 3),
("mixed.vcf.gz", None, None, 3),
# exclude PL since it has Number=G, which is not yet supported
("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"]),
# increase max_alt_alleles since scikit-allel does not truncate genotype calls
("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"], 7),
],
)
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
def test_all_fields(
shared_datadir, tmpdir, vcf_file, allel_exclude_fields, sgkit_exclude_fields
shared_datadir,
tmpdir,
vcf_file,
allel_exclude_fields,
sgkit_exclude_fields,
max_alt_alleles,
):
# change scikit-allel type defaults back to the VCF default
types = {
Expand All @@ -137,6 +146,7 @@ def test_all_fields(
fields=["*"],
exclude_fields=allel_exclude_fields,
types=types,
alt_number=max_alt_alleles,
)

field_defs = {
Expand All @@ -156,9 +166,11 @@ def test_all_fields(
exclude_fields=sgkit_exclude_fields,
field_defs=field_defs,
truncate_calls=True,
max_alt_alleles=max_alt_alleles,
)
sg_ds = sg.load_dataset(str(sg_vcfzarr_path))
sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel
del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel

# scikit-allel only records contigs for which there are actual variants,
# whereas sgkit records contigs from the header
Expand Down