This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit 5df0bb9

wip
jkobject committed Jun 12, 2024
1 parent 5471b59 commit 5df0bb9
Showing 22 changed files with 6,891 additions and 181 deletions.
41 changes: 4 additions & 37 deletions config/base.yml
@@ -23,40 +23,6 @@ trainer:
project: ${project}
save_dir: /pasteur/zeus/projets/p02/ml4ig_hot/Users/jkalfon/
offline: True
callbacks:
- class_path: scprint.trainer.TrainingMode
init_args:
do_denoise: True
noise:
- 0.6
do_cce: False
do_ecs: True
do_mvc: False
do_generate: True
do_adv_cls: False
do_next_tp: False
do_adv_batch: False
run_full_forward: False
do_cls: True
class_scale: 1.5
warmup_duration: 500
fused_adam: True
mask_ratio: []
- class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
init_args:
swa_lrs: 0.03
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
save_top_k: 6
save_last: True
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
patience: 3
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
#- class_path: lightning.pytorch.callbacks.LearningRateFinder
#init_args:
# mode: exponential
@@ -74,6 +40,7 @@ model:
fused_mlp: False
fused_bias_fc: False
drop_path_rate: 0.02
freeze_embeddings: True
pred_embedding:
- cell_type_ontology_term_id
- disease_ontology_term_id
@@ -88,14 +55,14 @@ data:
collection_name: all no zhang13M #preprocessed dataset
how: random expr
max_len: 2200
weight_scaler: 100
weight_scaler: 50
do_gene_pos: ./data/main/biomart_pos.parquet
add_zero_genes: 0
train_oversampling_per_epoch: 0.4
train_oversampling_per_epoch: 0.3
validation_split: 0.02
test_split: 0.02
batch_size: 64
num_workers: 16
num_workers: 12
# TODO: drop tissue & dev stage until part_of is taken into account
hierarchical_clss:
- cell_type_ontology_term_id
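All of these files use the LightningCLI convention of declaring callbacks by import path. As a rough illustration of the mechanism (scPRINT's actual resolution goes through LightningCLI/jsonargparse, so treat this as a sketch only), a `class_path`/`init_args` entry amounts to an import plus a keyword-argument call:

```python
# Sketch: how a LightningCLI-style class_path/init_args entry becomes an
# object. Illustration only; the real parsing is done by jsonargparse.
import importlib

def instantiate(spec: dict):
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))

swa = instantiate({
    "class_path": "lightning.pytorch.callbacks.StochasticWeightAveraging",
    "init_args": {"swa_lrs": 0.03},
})
```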
18 changes: 12 additions & 6 deletions config/masking.yml
@@ -3,19 +3,18 @@ trainer:
- class_path: scprint.trainer.TrainingMode
init_args:
do_denoise: False
noise:
- 0.5
noise: []
do_cce: False
do_ecs: False
do_mvc: False
do_generate: True
do_ecs: True
do_mvc: True
do_generate: False
do_adv_cls: False
do_next_tp: False
do_adv_batch: False
run_full_forward: False
do_cls: True
class_scale: 1.5
d warmup_duration: 500
warmup_duration: 500
fused_adam: True
mask_ratio: [0.4]
- class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
@@ -33,3 +32,10 @@ d warmup_duration: 500
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
model:
nhead: 4
nlayers: 8
layers_cls: [256]
d_model: 256
data:
collection_name: all no zhang13M #preprocessed dataset
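After this change, masking.yml only overrides what differs from base.yml: the training objectives (ECS and MVC on, denoising and generation off, mask_ratio 0.4) plus a small model preset. Assuming the configs are stacked the way LightningCLI's repeated `--config` flags work (later files override earlier keys), the merge behaves roughly like this sketch, with PyYAML standing in for the real parser:

```python
# Sketch of stacking a task config on top of base.yml; later keys win.
# Assumes PyYAML; LightningCLI/jsonargparse does the real merging.
import yaml

def deep_merge(base: dict, override: dict) -> dict:
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

with open("config/base.yml") as b, open("config/masking.yml") as m:
    cfg = deep_merge(yaml.safe_load(b), yaml.safe_load(m))

print(cfg["model"]["d_model"])  # 256, taken from masking.yml
```

Note that with this naive merge, lists such as `trainer.callbacks` are replaced wholesale rather than concatenated.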
41 changes: 1 addition & 40 deletions config/pretrain_large.yml
@@ -1,48 +1,9 @@
trainer:
callbacks:
- class_path: scprint.trainer.TrainingMode
init_args:
do_denoise: True
noise:
- 0.6
do_cce: True
do_ecs: False
do_mvc: False
do_generate: True
do_adv_cls: False
do_next_tp: False
do_adv_batch: False
run_full_forward: False
do_cls: True
class_scale: 1
warmup_duration: 500
fused_adam: True
mask_ratio: []
- class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
init_args:
swa_lrs: 0.03
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
save_top_k: 6
save_last: True
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
patience: 3
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
strategy: ddp_find_unused_parameters_true
model:
nhead: 8
nlayers: 16 #used to be 12
layers_cls: [512]
d_model: 512
freeze_embeddings: False
data:
max_len: 2200
weight_scaler: 100
train_oversampling_per_epoch: 0.5
batch_size: 16
collection_name: all no zhang13M
num_workers: 12
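With nhead 8, nlayers 16, and d_model 512, the large preset is roughly a 50M-parameter transformer. A back-of-envelope estimate, assuming standard blocks (4·d² for attention, 8·d² for a 4x MLP) and ignoring embeddings and output heads:

```python
# Rough parameter count for the large preset; assumes standard
# transformer blocks and excludes embeddings and classification heads.
d_model, nlayers = 512, 16
total = nlayers * 12 * d_model**2   # 4*d^2 attention + 8*d^2 MLP per layer
print(f"~{total / 1e6:.0f}M parameters")  # ~50M
```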
13 changes: 0 additions & 13 deletions config/pretrain_medium.yml
@@ -1,20 +1,7 @@
trainer:
limit_train_batches: 5000
limit_val_batches: 1000
max_time:
hours: 72
model:
nhead: 4
nlayers: 8
freeze_embeddings: True
layers_cls: [256]
d_model: 256
data:
collection_name: all no zhang13M #preprocessed dataset
max_len: 2200
weight_scaler: 40
train_oversampling_per_epoch: 0.4
validation_split: 0.02
test_split: 0.02
batch_size: 64
num_workers: 12
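The medium preset now inherits everything from base.yml; among the deleted overrides, `max_time: {hours: 72}` relies on Lightning accepting a dict of time units. A sketch of the equivalent Trainer call (illustrative, not the repo's code):

```python
from lightning.pytorch import Trainer

# max_time accepts a dict of time units (also a string or timedelta);
# the deleted medium override capped a run at 72 hours like this.
trainer = Trainer(max_time={"hours": 72})
```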
57 changes: 3 additions & 54 deletions config/pretrain_vlarge.yml
@@ -1,68 +1,17 @@
trainer:
strategy: ddp_find_unused_parameters_true
num_nodes: 1
max_time:
hours: 72
log_every_n_steps: 300
precision: 16-mixed
gradient_clip_val: 500
limit_train_batches: 36000
limit_val_batches: 4000
limit_train_batches: 14000
limit_val_batches: 2000
reload_dataloaders_every_n_epochs: 1
accumulate_grad_batches: 6
callbacks:
- class_path: scprint.trainer.TrainingMode
init_args:
do_denoise: True
noise:
- 0.6
do_cce: True
do_ecs: False
do_mvc: False
do_generate: True
do_adv_cls: False
do_next_tp: False
do_adv_batch: False
run_full_forward: False
do_cls: True
class_scale: 1
warmup_duration: 1000 # prev 3000 (maybe it helped?)
fused_adam: True
mask_ratio: []
#- class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
# init_args:
# swa_lrs: 0.0
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
save_top_k: 6
save_last: True
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val_loss
patience: 8
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
#- class_path: lightning.pytorch.callbacks.LearningRateFinder
#init_args:
# mode: exponential

#plugins:
# - class_path: lightning.pytorch.plugins.environments.SLURMEnvironment
# requeue_signal: signal.SIGHUP
accumulate_grad_batches: 2
model:
nhead: 20
lr: 0.00001
nlayers: 32
layers_cls: [512]
d_model: 1280
freeze_embeddings: True
data:
collection_name: all no zhang13M #preprocessed dataset #all no zhang13M
how: random expr
max_len: 2200
weight_scaler: 2000
train_oversampling_per_epoch: 1
batch_size: 3
num_workers: 9
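The very-large preset trades per-device batch size for gradient accumulation: batch_size stays at 3 while accumulate_grad_batches drops from 6 to 2. The resulting sequences per optimizer step, with the device count as an assumption (it is not set in this commit):

```python
# Effective batch size per optimizer step for the vlarge preset.
batch_size = 3    # data.batch_size
accumulate = 2    # trainer.accumulate_grad_batches (was 6)
devices = 8       # assumed GPUs under DDP; not specified in this diff
print(batch_size * accumulate * devices)  # 48 sequences per step
```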