Skip to content

Commit f45908c

Browse files
committed
added tiledb dataset
1 parent baacaca commit f45908c

File tree

1 file changed

+164
-10
lines changed

1 file changed

+164
-10
lines changed

src/grelu/data/dataset.py

Lines changed: 164 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
import pandas as pd
1313
import scipy
14+
import tiledb
1415
from einops import rearrange
1516
from torch import Tensor
1617
from torch.utils.data import Dataset
@@ -192,26 +193,38 @@ def get_labels(self) -> np.ndarray:
192193

193194
return labels
194195

196+
def _idx_to_raw_pair(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
197+
return self.seqs[idx], self.labels[idx]
198+
199+
def _idx_to_raw_seq(self, idx: int) -> np.ndarray:
200+
return self.seqs[idx]
201+
195202
def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, Tensor]]:
196203
# Get sequence and augmentation indices
197204
seq_idx, augment_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented))
198205

199-
# Get current sequence and label
200-
seq = self.seqs[seq_idx]
201-
label = self.labels[seq_idx]
206+
if self.predict:
207+
# Get raw sequence
208+
seq = self._idx_to_raw_seq(seq_idx)
202209

203-
# Augment
204-
seq, label = self.augmenter(seq=seq, label=label, idx=augment_idx)
210+
# Augment
211+
seq = self.augmenter(seq=seq, idx=augment_idx)
205212

206-
# One-hot encode
207-
seq = indices_to_one_hot(seq)
213+
# One-hot encode
214+
seq = indices_to_one_hot(seq)
208215

209-
# If using in prediction, return only the sequence
210-
if self.predict:
211216
return seq
212217

213-
# Otherwise, return the sequence/label pair
214218
else:
219+
# Get raw sequence-label pair
220+
seq, label = self._idx_to_raw_pair(seq_idx)
221+
222+
# Augment
223+
seq, label = self.augmenter(seq=seq, label=label, idx=augment_idx)
224+
225+
# One-hot encode
226+
seq = indices_to_one_hot(seq)
227+
215228
# Aggregate label
216229
if self.label_aggfunc is not None:
217230
label = rearrange(label, "t (l b) -> t l b", b=self.bin_size)
@@ -243,6 +256,8 @@ class DFSeqDataset(LabeledSeqDataset):
243256
they will not be reverse complemented.
244257
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
245258
This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
259+
seed: Random seed for reproducibility
260+
augment_mode: "random" or "serial"
246261
"""
247262

248263
def __init__(
@@ -312,6 +327,8 @@ class AnnDataSeqDataset(LabeledSeqDataset):
312327
False, they will not be reverse complemented.
313328
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
314329
This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
330+
seed: Random seed for reproducibility
331+
augment_mode: "random" or "serial"
315332
"""
316333

317334
def __init__(
@@ -388,6 +405,8 @@ class BigWigSeqDataset(LabeledSeqDataset):
388405
min_label_clip: Minimum value for label
389406
max_label_clip: Maximum value for label
390407
label_transform_func: Function to transform label values.
408+
seed: Random seed for reproducibility
409+
augment_mode: "random" or "serial"
391410
"""
392411

393412
def __init__(
@@ -445,6 +464,137 @@ def _load_labels(self, bw_files: Union[str, List[str]]) -> None:
445464
self.labels = read_bigwig(intervals, bw_files, aggfunc=None)
446465

447466

467+
class TileDBSeqDataset(LabeledSeqDataset):
468+
"""
469+
LabeledSeqDataset derived class for genomic intervals and TileDB files.
470+
TileDB files are created by grelu.data.preprocess.write_tiledb.
471+
472+
Args:
473+
intervals: A Pandas dataframe containing genomic intervals
474+
tdb_path: Path to tileDB.
475+
samples: A subset of samples to read
476+
seq_len: Uniform expected length (in base pairs) for output sequences
477+
end: Which end of the sequence to resize. Supported values are "left", "right"
478+
and "both".
479+
rc: If True, sequences will be augmented by reverse complementation. If False,
480+
they will not be reverse complemented.
481+
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
482+
This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
483+
max_pair_shift: Maximum number of bases to shift both the sequence and label for
484+
augmentation. If 0, sequence and label pairs will not be augmented by shifting.
485+
label_aggfunc: Function to aggregate the labels over bin_size.
486+
bin_size: Number of bases to aggregate in the label.
487+
min_label_clip: Minimum value for label
488+
max_label_clip: Maximum value for label
489+
label_transform_func: Function to transform label values.
490+
seed: Random seed for reproducibility
491+
augment_mode: "random" or "serial"
492+
"""
493+
494+
def __init__(
495+
self,
496+
intervals: pd.DataFrame,
497+
tdb_path: str,
498+
seq_len: Optional[int] = None,
499+
end: str = "both",
500+
rc: bool = False,
501+
max_seq_shift: int = 0,
502+
label_len: Optional[int] = None,
503+
max_pair_shift: int = 0,
504+
label_aggfunc: Optional[Union[str, Callable]] = np.sum,
505+
bin_size: Optional[int] = None,
506+
min_label_clip: Optional[int] = None,
507+
max_label_clip: Optional[int] = None,
508+
label_transform_func: Optional[Union[str, Callable]] = None,
509+
seed: Optional[int] = None,
510+
augment_mode: str = "serial",
511+
) -> None:
512+
513+
# Paths to tileDB
514+
self.tdb_path = tdb_path
515+
self._task_uri = f"{tdb_path}/tasks"
516+
self._chrom_uri = f"{tdb_path}/chroms"
517+
518+
# Open tileDB database
519+
with tiledb.open(self._task_uri, "r") as fp:
520+
self.tasks = pd.DataFrame(fp[:])
521+
522+
with tiledb.open(self._chrom_uri, "r") as fp:
523+
self.chroms = pd.DataFrame(fp[:])
524+
525+
self._data_uris = {row.chrom: row.uri for row in self.chroms.itertuples()}
526+
self.open()
527+
528+
super().__init__(
529+
seqs=intervals,
530+
labels=None,
531+
tasks=self.tasks,
532+
seq_len=seq_len,
533+
genome=None,
534+
end=end,
535+
rc=rc,
536+
max_seq_shift=max_seq_shift,
537+
label_len=label_len,
538+
max_pair_shift=max_pair_shift,
539+
label_aggfunc=label_aggfunc,
540+
bin_size=bin_size,
541+
min_label_clip=min_label_clip,
542+
max_label_clip=max_label_clip,
543+
label_transform_func=label_transform_func,
544+
seed=seed,
545+
augment_mode=augment_mode,
546+
)
547+
548+
def _load_seqs(self, seqs: pd.DataFrame) -> None:
549+
seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end)
550+
self.seqs = seqs
551+
552+
def open(self) -> None:
553+
self.dataset = {
554+
chrom: tiledb.open(uri, "r") for chrom, uri in self._data_uris.items()
555+
}
556+
557+
def close(self) -> None:
558+
for _, v in self.data.items():
559+
v.close()
560+
561+
def _idx_to_raw_pair(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
562+
interval = self.seqs.iloc[idx]
563+
data = self.dataset[interval.chrom][:, interval.start : interval.end]["data"]
564+
return data[0], data[1:]
565+
566+
def _idx_to_raw_seq(self, idx: int) -> np.ndarray:
567+
interval = self.seqs.iloc[idx]
568+
return self.dataset[interval.chrom][0, interval.start : interval.end]["data"]
569+
570+
def _load_labels(self, labels: np.ndarray) -> None:
571+
pass
572+
573+
def get_labels(self, intervals) -> np.ndarray:
574+
"""
575+
Return the labels as a numpy array of shape (B, T, L). This does not
576+
account for data augmentation.
577+
"""
578+
labels = []
579+
for interval in intervals.iterrows():
580+
labels.append(self.data[interval.chrom][interval.start : interval.end, :])
581+
labels = np.vstack(labels)
582+
583+
# Aggregate label
584+
if self.label_aggfunc is not None:
585+
labels = rearrange(
586+
labels,
587+
"batch task (length bin_size) -> batch task length bin_size",
588+
bin_size=self.bin_size,
589+
)
590+
labels = self.label_aggfunc(labels, axis=-1)
591+
592+
# Transform label
593+
labels = self.label_transform(labels)
594+
595+
return labels
596+
597+
448598
class SeqDataset(Dataset):
449599
"""
450600
Dataset to cycle through unlabeled sequences for inference. All sequences
@@ -462,6 +612,8 @@ class SeqDataset(Dataset):
462612
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
463613
This is normally a small value (< 10). If 0, sequences will not be
464614
augmented by shifting.
615+
seed: Random seed for reproducibility
616+
augment_mode: "random" or "serial"
465617
"""
466618

467619
def __init__(
@@ -542,6 +694,8 @@ class VariantDataset(Dataset):
542694
protect: A list of positions to protect from mutation.
543695
n_mutated_seqs: Number of mutated sequences to generate from each input
544696
sequence for data augmentation.
697+
seed: Random seed for reproducibility
698+
augment_mode: "random" or "serial"
545699
"""
546700

547701
def __init__(

0 commit comments

Comments
 (0)