11
11
import numpy as np
12
12
import pandas as pd
13
13
import scipy
14
+ import tiledb
14
15
from einops import rearrange
15
16
from torch import Tensor
16
17
from torch .utils .data import Dataset
@@ -192,26 +193,38 @@ def get_labels(self) -> np.ndarray:
192
193
193
194
return labels
194
195
196
+ def _idx_to_raw_pair (self , idx : int ) -> Tuple [np .ndarray , np .ndarray ]:
197
+ return self .seqs [idx ], self .labels [idx ]
198
+
199
+ def _idx_to_raw_seq (self , idx : int ) -> np .ndarray :
200
+ return self .seqs [idx ]
201
+
195
202
def __getitem__ (self , idx : int ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
196
203
# Get sequence and augmentation indices
197
204
seq_idx , augment_idx = _split_overall_idx (idx , (self .n_seqs , self .n_augmented ))
198
205
199
- # Get current sequence and label
200
- seq = self . seqs [ seq_idx ]
201
- label = self .labels [ seq_idx ]
206
+ if self . predict :
207
+ # Get raw sequence
208
+ seq = self ._idx_to_raw_seq ( seq_idx )
202
209
203
- # Augment
204
- seq , label = self .augmenter (seq = seq , label = label , idx = augment_idx )
210
+ # Augment
211
+ seq = self .augmenter (seq = seq , idx = augment_idx )
205
212
206
- # One-hot encode
207
- seq = indices_to_one_hot (seq )
213
+ # One-hot encode
214
+ seq = indices_to_one_hot (seq )
208
215
209
- # If using in prediction, return only the sequence
210
- if self .predict :
211
216
return seq
212
217
213
- # Otherwise, return the sequence/label pair
214
218
else :
219
+ # Get raw sequence-label pair
220
+ seq , label = self ._idx_to_raw_pair (seq_idx )
221
+
222
+ # Augment
223
+ seq , label = self .augmenter (seq = seq , label = label , idx = augment_idx )
224
+
225
+ # One-hot encode
226
+ seq = indices_to_one_hot (seq )
227
+
215
228
# Aggregate label
216
229
if self .label_aggfunc is not None :
217
230
label = rearrange (label , "t (l b) -> t l b" , b = self .bin_size )
@@ -243,6 +256,8 @@ class DFSeqDataset(LabeledSeqDataset):
243
256
they will not be reverse complemented.
244
257
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
245
258
This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
259
+ seed: Random seed for reproducibility
260
+ augment_mode: "random" or "serial"
246
261
"""
247
262
248
263
def __init__ (
@@ -312,6 +327,8 @@ class AnnDataSeqDataset(LabeledSeqDataset):
312
327
False, they will not be reverse complemented.
313
328
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
314
329
This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
330
+ seed: Random seed for reproducibility
331
+ augment_mode: "random" or "serial"
315
332
"""
316
333
317
334
def __init__ (
@@ -388,6 +405,8 @@ class BigWigSeqDataset(LabeledSeqDataset):
388
405
min_label_clip: Minimum value for label
389
406
max_label_clip: Maximum value for label
390
407
label_transform_func: Function to transform label values.
408
+ seed: Random seed for reproducibility
409
+ augment_mode: "random" or "serial"
391
410
"""
392
411
393
412
def __init__ (
@@ -445,6 +464,137 @@ def _load_labels(self, bw_files: Union[str, List[str]]) -> None:
445
464
self .labels = read_bigwig (intervals , bw_files , aggfunc = None )
446
465
447
466
467
class TileDBSeqDataset(LabeledSeqDataset):
    """
    LabeledSeqDataset derived class for genomic intervals and TileDB files.
    TileDB files are created by grelu.data.preprocess.write_tiledb.

    Sequences and labels are not held in memory: they are read from the
    per-chromosome TileDB arrays each time an item is requested. In each
    array, row 0 of the "data" attribute is treated as the sequence and the
    remaining rows as the labels (one row per task).

    Args:
        intervals: A Pandas dataframe containing genomic intervals
        tdb_path: Path to tileDB.
        seq_len: Uniform expected length (in base pairs) for output sequences
        end: Which end of the sequence to resize. Supported values are "left", "right"
            and "both".
        rc: If True, sequences will be augmented by reverse complementation. If False,
            they will not be reverse complemented.
        max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
            This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
        label_len: Expected length of the label; forwarded unchanged to LabeledSeqDataset.
        max_pair_shift: Maximum number of bases to shift both the sequence and label for
            augmentation. If 0, sequence and label pairs will not be augmented by shifting.
        label_aggfunc: Function to aggregate the labels over bin_size.
        bin_size: Number of bases to aggregate in the label.
        min_label_clip: Minimum value for label
        max_label_clip: Maximum value for label
        label_transform_func: Function to transform label values.
        seed: Random seed for reproducibility
        augment_mode: "random" or "serial"
    """

    def __init__(
        self,
        intervals: pd.DataFrame,
        tdb_path: str,
        seq_len: Optional[int] = None,
        end: str = "both",
        rc: bool = False,
        max_seq_shift: int = 0,
        label_len: Optional[int] = None,
        max_pair_shift: int = 0,
        label_aggfunc: Optional[Union[str, Callable]] = np.sum,
        bin_size: Optional[int] = None,
        min_label_clip: Optional[int] = None,
        max_label_clip: Optional[int] = None,
        label_transform_func: Optional[Union[str, Callable]] = None,
        seed: Optional[int] = None,
        augment_mode: str = "serial",
    ) -> None:

        # Paths to tileDB
        self.tdb_path = tdb_path
        self._task_uri = f"{tdb_path}/tasks"
        self._chrom_uri = f"{tdb_path}/chroms"

        # Read the task and chromosome tables; these handles are only needed
        # here, so they are closed as soon as the tables are in memory.
        with tiledb.open(self._task_uri, "r") as fp:
            self.tasks = pd.DataFrame(fp[:])

        with tiledb.open(self._chrom_uri, "r") as fp:
            self.chroms = pd.DataFrame(fp[:])

        # Map each chromosome name to the URI of its data array and open them
        # all. These handles stay open until close() is called.
        self._data_uris = {row.chrom: row.uri for row in self.chroms.itertuples()}
        self.open()

        super().__init__(
            seqs=intervals,
            labels=None,
            tasks=self.tasks,
            seq_len=seq_len,
            genome=None,
            end=end,
            rc=rc,
            max_seq_shift=max_seq_shift,
            label_len=label_len,
            max_pair_shift=max_pair_shift,
            label_aggfunc=label_aggfunc,
            bin_size=bin_size,
            min_label_clip=min_label_clip,
            max_label_clip=max_label_clip,
            label_transform_func=label_transform_func,
            seed=seed,
            augment_mode=augment_mode,
        )

    def _load_seqs(self, seqs: pd.DataFrame) -> None:
        """
        Resize the intervals to the padded sequence length and store them.
        Only the interval coordinates are kept; bases are read lazily.
        """
        seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end)
        self.seqs = seqs

    def open(self) -> None:
        """
        Open one read-only TileDB handle per chromosome and store the handles
        in `self.dataset`, keyed by chromosome name.
        """
        self.dataset = {
            chrom: tiledb.open(uri, "r") for chrom, uri in self._data_uris.items()
        }

    def close(self) -> None:
        """
        Close every TileDB handle opened by `open()`.
        """
        # The handles live in `self.dataset` (set by open()).
        for handle in self.dataset.values():
            handle.close()

    def _idx_to_raw_pair(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Read the raw sequence and labels for the interval at position `idx`.

        Returns:
            A tuple (sequence, labels), where the sequence is row 0 of the
            stored data and the labels are all remaining rows.
        """
        interval = self.seqs.iloc[idx]
        data = self.dataset[interval.chrom][:, interval.start : interval.end]["data"]
        return data[0], data[1:]

    def _idx_to_raw_seq(self, idx: int) -> np.ndarray:
        """
        Read only the raw sequence (row 0 of the stored data) for the interval
        at position `idx`.
        """
        interval = self.seqs.iloc[idx]
        return self.dataset[interval.chrom][0, interval.start : interval.end]["data"]

    def _load_labels(self, labels: np.ndarray) -> None:
        """
        No-op: labels are read from TileDB on demand rather than preloaded.
        """
        pass

    def get_labels(self, intervals) -> np.ndarray:
        """
        Return the labels as a numpy array of shape (B, T, L). This does not
        account for data augmentation.

        Args:
            intervals: A Pandas dataframe of genomic intervals with columns
                "chrom", "start" and "end". All intervals must have the same
                length so that their labels can be stacked.

        Returns:
            Labels for the given intervals, aggregated over bin_size if
            label_aggfunc is set, and then passed through the label transform.
        """
        # Read each interval exactly as _idx_to_raw_pair does and drop row 0,
        # which holds the sequence rather than a label track.
        labels = [
            self.dataset[interval.chrom][:, interval.start : interval.end]["data"][1:]
            for interval in intervals.itertuples(index=False)
        ]
        # Stack along a new leading axis to obtain (B, T, L).
        labels = np.stack(labels)

        # Aggregate label
        if self.label_aggfunc is not None:
            labels = rearrange(
                labels,
                "batch task (length bin_size) -> batch task length bin_size",
                bin_size=self.bin_size,
            )
            labels = self.label_aggfunc(labels, axis=-1)

        # Transform label
        labels = self.label_transform(labels)

        return labels
596
+
597
+
448
598
class SeqDataset (Dataset ):
449
599
"""
450
600
Dataset to cycle through unlabeled sequences for inference. All sequences
@@ -462,6 +612,8 @@ class SeqDataset(Dataset):
462
612
max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
463
613
This is normally a small value (< 10). If 0, sequences will not be
464
614
augmented by shifting.
615
+ seed: Random seed for reproducibility
616
+ augment_mode: "random" or "serial"
465
617
"""
466
618
467
619
def __init__ (
@@ -542,6 +694,8 @@ class VariantDataset(Dataset):
542
694
protect: A list of positions to protect from mutation.
543
695
n_mutated_seqs: Number of mutated sequences to generate from each input
544
696
sequence for data augmentation.
697
+ seed: Random seed for reproducibility
698
+ augment_mode: "random" or "serial"
545
699
"""
546
700
547
701
def __init__ (
0 commit comments