# main.py (forked from contrebande-labs/charred)
import os

# jax/flax
import jax
from flax import jax_utils
from flax.core.frozen_dict import unfreeze
from flax.training import train_state

# internal code
from architecture import setup_model
from args import parse_args
from optimizer import setup_optimizer
from training_loop import training_loop
from monitoring import wandb_close, wandb_init


def main():
    args = parse_args()

    output_dir = args.output_dir
    load_pretrained = os.path.exists(output_dir) and os.path.isdir(output_dir)

    # Setup WandB for logging & tracking
    log_wandb = args.log_wandb
    if log_wandb:
        wandb_init(args)

    # init random number generator
    seed = args.seed
    seed_rng = jax.random.PRNGKey(seed)
    rng, training_from_scratch_rng_params = jax.random.split(seed_rng)
    print("random generator setup...")

    # Pretrained/frozen and trainable model setup
    text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model(
        seed,
        args.mixed_precision,
        load_pretrained,
        output_dir,
        training_from_scratch_rng_params,
    )
    print("models setup...")

    # Optimization & scheduling setup
    optimizer = setup_optimizer(
        args.learning_rate,
        args.adam_beta1,
        args.adam_beta2,
        args.adam_epsilon,
        args.adam_weight_decay,
        args.max_grad_norm,
    )
    print("optimizer setup...")

    # Training state setup
    unet_training_state = train_state.TrainState.create(
        apply_fn=unet,
        params=unfreeze(unet_params),
        tx=optimizer,
    )
    print("training state initialized...")

    # JAX device data replication
    replicated_state = jax_utils.replicate(unet_training_state)
    replicated_text_encoder_params = jax_utils.replicate(text_encoder_params)
    replicated_vae_params = jax_utils.replicate(vae_params)
    print("states & params replicated to TPUs...")

    # Train!
    print("Training loop init...")
    training_loop(
        text_encoder,
        replicated_text_encoder_params,
        vae,
        replicated_vae_params,
        unet,
        replicated_state,
        rng,
        args.max_train_steps,
        args.num_train_epochs,
        args.train_batch_size,
        output_dir,
        log_wandb,
    )
    print("Training loop done...")

    if log_wandb:
        wandb_close()


if __name__ == "__main__":
    main()
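
# Example invocation (a sketch; the actual CLI flags are defined in args.py and
# may be named differently than assumed here):
#   python main.py --output_dir ./charred-ckpt --seed 42 --learning_rate 1e-4 \
#       --train_batch_size 8 --max_train_steps 100000 --log_wandb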