cfg_predictor.py
72 lines (47 loc) · 2 KB
import ml_collections

from chioso.modules import LinearPredictor


def get_config():
    config = ml_collections.ConfigDict()
    config.name = "celltype_linear_model"

    config.dataset = ml_collections.ConfigDict()
    # Training data location. This is the output of chioso.pp-ref
    config.dataset.path = "ref_data/tome.ds"
    # Output file name
    config.dataset.outname = "ref_embedding.ds"
    # We randomly mask off a portion of the input data during training, which helps
    # increase model robustness.
    config.dataset.dropout = 0.5

    config.train = ml_collections.ConfigDict()
    # Seed value for the random number generator
    config.train.seed = 1234
    # Training batch size. Adjust according to the available GPU memory.
    config.train.batchsize = 128
    # Number of training steps. The default value works well for a medium-sized
    # (~1 million cells) dataset.
    config.train.train_steps = 100000
    # How frequently (in training steps) validation metrics are computed
    config.train.validation_interval = 10000
    # Learning rate
    config.train.lr = 1e-4
    # L2 regularization factor for the model weights
    config.train.weight_decay = 1e-3
    # Fraction of the training data reserved for validation
    config.train.val_split = 0.2
    # Whether to train with a balanced loss (True) or a plain cross-entropy loss (False)
    config.train.balanced_loss = True

    config.model = ml_collections.ConfigDict()
    # Model type. Don't change.
    config.model.type = LinearPredictor
    config.model.config = ml_collections.ConfigDict()
    # Number of genes in the dataset
    config.model.config.n_genes = 27504
    # Number of cell types
    config.model.config.dim_out = 68
    # Dropout rate during training
    config.model.config.dropout = 0.2
    # Whether to normalize the gene expression profile
    config.model.config.normalize = False
    # Whether to apply a log1p transformation
    config.model.config.log_transform = False
    # Dimension of the latent features
    config.model.config.dim_hidden = 256

    return config
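
Below is a minimal sketch of how a driver script might load and adjust this config before training. The file name example_train.py and the override values are hypothetical, and only standard ml_collections behavior is assumed; the actual chioso training entry point may consume the config differently.

# example_train.py (hypothetical driver; illustrates ml_collections usage only)
from cfg_predictor import get_config

config = get_config()

# Fields can be overridden in code before training, e.g. for a quick smoke test:
config.train.train_steps = 1000
config.train.validation_interval = 500

# ConfigDict supports attribute-style access to nested values:
print(config.model.config.n_genes)   # 27504
print(config.train.lr)               # 0.0001

# Lock the structure so a typo in a field name raises an error instead of
# silently creating a new entry:
config.lock()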