# data_collator.py (forked from bowang-lab/scGPT)
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Tuple

import torch

from .preprocess import binning


@dataclass
class DataCollator:
"""
Data collator for the mask value learning task. It pads the sequences to
the maximum length in the batch and masks the gene expression values.
Args:
do_padding (:obj:`bool`): whether to pad the sequences to the max length.
pad_token_id (:obj:`int`, optional): the token id to use for padding.
This is required if do_padding is True.
pad_value (:obj:`int`): the value to use for padding the expression
values to the max length.
do_mlm (:obj:`bool`): whether to do masking with MLM.
do_binning (:obj:`bool`): whether to bin the expression values.
mlm_probability (:obj:`float`): the probability of masking with MLM.
mask_value (:obj:`int`): the value to fill at the expression postions
that are masked.
max_length (:obj:`int`, optional): the maximum length of the sequences.
This is required if do_padding is True.
sampling (:obj:`bool`): whether to do sampling instead of truncation if
length > max_length.
keep_first_n_tokens (:obj:`int`): the number of tokens in the beginning
of the sequence to keep unchanged from sampling. This is useful when
special tokens have been added to the beginning of the sequence.
Default to 1.
"""
    do_padding: bool = True
    pad_token_id: Optional[int] = None
    pad_value: int = 0
    do_mlm: bool = True
    do_binning: bool = True
    mlm_probability: float = 0.15
    mask_value: int = -1
    max_length: Optional[int] = None
    sampling: bool = True
    keep_first_n_tokens: int = 1

    def __post_init__(self):
        if self.do_padding:
            if self.pad_token_id is None:
                raise ValueError("`pad_token_id` is required if `do_padding`.")
            if self.max_length is None:
                raise ValueError("`max_length` is required if `do_padding`.")

        if self.mlm_probability <= 0 or self.mlm_probability >= 1:
            raise ValueError("`mlm_probability` must be between 0 and 1.")

        # Only compare against max_length when it is set; it may legitimately
        # be None when do_padding is False.
        if self.keep_first_n_tokens < 0 or (
            self.max_length is not None
            and self.keep_first_n_tokens > self.max_length
        ):
            raise ValueError(
                "`keep_first_n_tokens` must be between 0 and `max_length` "
                f"({self.max_length})."
            )
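
    # For instance (illustrative), constructing DataCollator() with all
    # defaults raises ValueError at construction time: do_padding defaults to
    # True, which requires both `pad_token_id` and `max_length` to be set.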

    def __call__(
        self, examples: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Each example is like:
            {'id': tensor(184117),
             'genes': tensor([36572, 17868, ..., 17072]),
             'expressions': tensor([ 0.,  2., ..., 18.])}
        """
        if not isinstance(examples[0], Mapping):
            raise NotImplementedError

        device = examples[0]["genes"].device

        max_ori_len = max(len(example["genes"]) for example in examples)
        _max_length = self.max_length if max_ori_len >= self.max_length else max_ori_len

        # pad and truncate
        padded_genes = []
        padded_expressions = []
        for i in range(len(examples)):
            genes = examples[i]["genes"]
            expressions = examples[i]["expressions"]
            if self.do_binning:
                expressions[self.keep_first_n_tokens :] = binning(
                    row=expressions[self.keep_first_n_tokens :],
                    n_bins=51,
                )
            genes, expressions = self._sample_or_truncate_plus_pad(
                genes, expressions, _max_length
            )  # torch tensors of length _max_length
            padded_genes.append(genes)
            padded_expressions.append(expressions)

        padded_genes = torch.stack(padded_genes, dim=0).to(device)
        padded_expressions = torch.stack(padded_expressions, dim=0).to(device)

        data_dict = {
            "gene": padded_genes,
            "expr": padded_expressions,
        }

        # mask
        if self.do_mlm:
            masked_expressions = self._mask(padded_expressions)
        else:
            masked_expressions = padded_expressions
        data_dict["masked_expr"] = masked_expressions

        return data_dict
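
    # Illustrative sketch of a collated batch (shapes assume two examples and
    # an effective _max_length of 5; names follow the docstring above):
    #   out = collator(examples)
    #   out["gene"].shape == out["expr"].shape == torch.Size([2, 5])
    #   out["masked_expr"] equals out["expr"] except at the Bernoulli-sampled
    #   MLM positions, which are filled with `mask_value`.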

    def _mask(self, expressions: torch.Tensor) -> torch.Tensor:
        """
        Mask the expression values with MLM.
        """
        device = expressions.device
        shape = expressions.shape

        probability_matrix = torch.full(shape, self.mlm_probability)
        # set padded position probability to 0
        probability_matrix[expressions.eq(self.pad_value)] = 0
        if self.keep_first_n_tokens > 0:
            probability_matrix[:, : self.keep_first_n_tokens] = 0

        mask = torch.bernoulli(probability_matrix).bool()
        mask = mask.to(device)

        masked_expressions = expressions.masked_fill(mask, self.mask_value)
        return masked_expressions
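
    # Worked sketch of the masking rule (illustrative values): with
    # pad_value=-2, keep_first_n_tokens=1, and mlm_probability=0.15, a row
    #   expr  = [0.0, 1.0, 3.0, -2.0, -2.0]
    # receives the probability row
    #   probs = [0.00, 0.15, 0.15, 0.00, 0.00]
    # so only real, non-leading positions can be drawn by torch.bernoulli and
    # replaced with mask_value.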

    def _sample_or_truncate_plus_pad(
        self,
        genes: torch.LongTensor,
        expressions: torch.Tensor,
        max_length: int,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        assert len(genes) == len(expressions)
        if len(genes) == max_length:
            return genes, expressions
        if len(genes) > max_length:  # sample or truncate
            if self.sampling:
                return self._sample(genes, expressions, max_length)
            else:
                return genes[:max_length], expressions[:max_length]
        else:  # pad
            return self._pad(genes, expressions, max_length)

    def _sample(
        self,
        genes: torch.LongTensor,
        expressions: torch.Tensor,
        max_length: int,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        # NOTE: the fastest ways to sample in torch have been benchmarked here
        # https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146/19
        # it shows that randperm on gpu is the fastest.
        # NOTE: also, the current implementation permutes the order of the genes
        # and expressions, although it is probably a nice augmentation.
        device = genes.device
        if self.keep_first_n_tokens == 0:
            indices = torch.randperm(len(genes), device=device)[:max_length]
            return genes[indices], expressions[indices]

        # keep the first n tokens unchanged
        _n = self.keep_first_n_tokens
        indices = torch.randperm(len(genes) - _n, device=device)[: max_length - _n]
        indices = torch.cat([torch.arange(_n, device=device), indices + _n], dim=0)
        return genes[indices], expressions[indices]
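
    # Illustrative draw: with keep_first_n_tokens=1, 6 input tokens, and
    # max_length=4, the construction above yields e.g.
    #   indices = cat([arange(1, device=device), randperm(5)[:3] + 1])
    #           -> tensor([0, 4, 2, 5])
    # i.e. position 0 is always kept, followed by a random subset (in random
    # order) of positions 1..5.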

    def _pad(
        self,
        genes: torch.LongTensor,
        expressions: torch.Tensor,
        max_length: int,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        device = genes.device
        genes = torch.cat(
            [
                genes,
                torch.full(
                    (max_length - len(genes),),
                    self.pad_token_id,
                    dtype=genes.dtype,
                    device=device,
                ),
            ]
        )
        expressions = torch.cat(
            [
                expressions,
                torch.full(
                    (max_length - len(expressions),),
                    self.pad_value,
                    dtype=expressions.dtype,
                    device=device,
                ),
            ]
        )
        return genes, expressions
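

if __name__ == "__main__":
    # Minimal smoke test (illustrative; not part of the original file). It
    # disables binning so the example does not depend on .preprocess.binning;
    # every other argument uses a value documented in the class docstring.
    # Because of the relative import above, run this as a module from the
    # package root, e.g. `python -m scgpt.data_collator` (the package name
    # may differ in your checkout).
    collator = DataCollator(
        pad_token_id=0,
        pad_value=-2,
        do_binning=False,
        max_length=6,
        keep_first_n_tokens=1,
    )
    examples = [
        {  # longer than max_length: sampled down to 6 tokens
            "id": torch.tensor(0),
            "genes": torch.arange(1, 9, dtype=torch.long),
            "expressions": torch.arange(8, dtype=torch.float),
        },
        {  # shorter than max_length: padded up to 6 tokens
            "id": torch.tensor(1),
            "genes": torch.arange(1, 4, dtype=torch.long),
            "expressions": torch.arange(3, dtype=torch.float),
        },
    ]
    batch = collator(examples)
    print(batch["gene"].shape)   # torch.Size([2, 6])
    print(batch["expr"].shape)   # torch.Size([2, 6])
    print(batch["masked_expr"])  # expr with ~15% of positions set to -1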