Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read BGEN files #36

Merged
merged 2 commits into from
Jun 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docker/prototype/environment.ci.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
channels:
- conda-forge
- defaults
dependencies:
- pytest
- pytest-datadir
- pylint
- hypothesis
- black
Expand Down
1 change: 1 addition & 0 deletions docker/prototype/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ dependencies:
- python-dotenv==0.11.0
- xlrd
- dask_labextension
- pybgen
1 change: 1 addition & 0 deletions notebooks/platform/xarray/lib/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from . import io
from .io.core import (
read_bgen,
read_plink,
write_zarr
)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/platform/xarray/lib/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def resolve_backend(self, fn: Callable, *args, **kwargs) -> Backend:

backend = next((b for b in self.backends.values() if is_compatible(b)), None)
if backend is None:
raise ValueError(f'No backend found for function "{fn.__name__}" (domain = "{self.domain}")')
raise ValueError(f'No compatible backend found for function "{fn.__name__}" (domain = "{self.domain}"). Check you have installed the required packages.')
return backend, kwargs

def dispatch(self, fn: Callable, *args, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions notebooks/platform/xarray/lib/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
DOMAIN = __name__.split('.')[-1]

try:
from . import pybgen_backend
except ImportError:
pass

try:
from . import pysnptools_backend
except ImportError:
Expand Down
7 changes: 7 additions & 0 deletions notebooks/platform/xarray/lib/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

DOMAIN = Domain(DOMAIN)
PLINK_DOMAIN = DOMAIN.append('plink')
BGEN_DOMAIN = DOMAIN.append('bgen')


@register_function(PLINK_DOMAIN, append=False)
Expand All @@ -15,6 +16,12 @@ def read_plink(path, backend=None, **kwargs) -> Dataset:
pass


@register_function(BGEN_DOMAIN, append=False)
def read_bgen(path, backend=None, **kwargs) -> Dataset:
"""Import BGEN dataset"""
pass


def write_zarr(ds: Dataset, path: PathType, rechunk: bool=False, **kwargs) -> "ZarrStore":
"""Write dataset to zarr

Expand Down
91 changes: 91 additions & 0 deletions notebooks/platform/xarray/lib/io/pybgen_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import dask.array as da
import numpy as np
from pybgen import PyBGEN
from pybgen.parallel import ParallelPyBGEN
import xarray as xr

from .. import core
from ..typing import PathType
from ..compat import Requirement
from ..dispatch import ClassBackend, register_backend
from .core import BGEN_DOMAIN


def _array_name(f, path):
return f.__qualname__ + ':' + str(path)


class BgenReader(object):

def __init__(self, path, dtype=np.float32):
self.path = str(path) # pybgen needs string paths

# Use ParallelPyBGEN only to get all the variant seek positions from the BGEN index.
# No parallel IO happens here.
with ParallelPyBGEN(self.path, probs_only=False) as bgen:
bgen._get_all_seeks()
self._seeks = bgen._seeks
n_variants = bgen.nb_variants
n_samples = bgen.nb_samples

self.shape = (n_variants, n_samples)
self.dtype = dtype
self.ndim = 2

self.sample_id = bgen.samples
# This may need chunking for large numbers of variants
variants = list(bgen.iter_variant_info())
tomwhite marked this conversation as resolved.
Show resolved Hide resolved
self.variant_id = [v.name for v in variants]
self.contig = [v.chrom for v in variants]
self.pos = [v.pos for v in variants]
self.a1 = [v.a1 for v in variants]
self.a2 = [v.a2 for v in variants]

def __getitem__(self, idx):
if not isinstance(idx, tuple):
raise IndexError(f'Indexer must be tuple (received {type(idx)})')
if len(idx) != self.ndim:
raise IndexError(f'Indexer must be two-item tuple (received {len(idx)} slices)')

# Restrict to seeks for this chunk
seeks_for_chunk = self._seeks[idx[0]]
if len(seeks_for_chunk) == 0:
return np.empty((0, 0), dtype=self.dtype)
with PyBGEN(self.path, probs_only=False) as bgen:
p = [probs for (_, probs) in bgen._iter_seeks(seeks_for_chunk)]
eric-czech marked this conversation as resolved.
Show resolved Hide resolved
eric-czech marked this conversation as resolved.
Show resolved Hide resolved
return np.stack(p)[:,idx[1]]


@register_backend(BGEN_DOMAIN)
class PyBgenBackend(ClassBackend):

id = 'pybgen'

def read_bgen(self, path: PathType, chunks='auto', lock=False):

bgen_reader = BgenReader(path)

vars = {
"sample_id": xr.DataArray(np.array(bgen_reader.sample_id), dims=["sample"]),
"variant_id": xr.DataArray(np.array(bgen_reader.variant_id), dims=["variant"]),
"contig": xr.DataArray(np.array(bgen_reader.contig), dims=["variant"]),
"pos": xr.DataArray(np.array(bgen_reader.pos), dims=["variant"]),
"a1": xr.DataArray(np.array(bgen_reader.a1), dims=["variant"]),
"a2": xr.DataArray(np.array(bgen_reader.a2), dims=["variant"]),
}

arr = da.from_array(
bgen_reader,
chunks=chunks,
lock=lock,
eric-czech marked this conversation as resolved.
Show resolved Hide resolved
asarray=False,
name=_array_name(self.read_bgen, path))

# pylint: disable=no-member
eric-czech marked this conversation as resolved.
Show resolved Hide resolved
ds = core.create_genotype_dosage_dataset(arr)
ds = ds.assign(vars)
return ds

@property
def requirements(self):
return [Requirement('pybgen'), Requirement('dask')]
Empty file.
Binary file not shown.
Binary file not shown.
22 changes: 22 additions & 0 deletions notebooks/platform/xarray/lib/tests/io/test_pybgen_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from lib import api
import numpy.testing as npt

# Note that shared_datadir is set up by pytest-datadir
def test_load_bgen(shared_datadir):
path = (shared_datadir / "example.bgen")
ds = api.read_bgen(path, chunks=(100, 500))

assert "sample" in ds.dims
assert "variant" in ds.dims

assert "data" in ds.variables
assert "sample_id" in ds.variables
assert "variant_id" in ds.variables
assert "contig" in ds.variables
assert "pos" in ds.variables
assert "a1" in ds.variables
assert "a2" in ds.variables

# check some of the data (in different chunks)
npt.assert_almost_equal(ds.data.values[1][0], 1.99, decimal=2)
npt.assert_almost_equal(ds.data.values[100][0], 0.16, decimal=2)