@@ -5,6 +5,7 @@
 import ignite.distributed as idist
 from data import setup_data
 from ignite.engine import Events
+from ignite.handlers import PiecewiseLinear
 from ignite.metrics import Accuracy, Loss
 from ignite.utils import manual_seed
 from models import setup_model
@@ -31,6 +32,15 @@ def run(local_rank: int, config: Any):
     model = idist.auto_model(setup_model(config.model))
     optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
     loss_fn = nn.CrossEntropyLoss().to(device=device)
+    milestones_values = [
+        (0, 0.0),
+        (
+            len(dataloader_train),
+            config.lr,
+        ),
+        (config.max_epochs * len(dataloader_train), 0.0),
+    ]
+    lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)
 
     # trainer and evaluator
     trainer = setup_trainer(
@@ -54,10 +64,17 @@ def run(local_rank: int, config: Any):
     logger.info("Configuration: \n%s", pformat(vars(config)))
     trainer.logger = evaluator.logger = logger
 
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
+
     # setup ignite handlers
     #::: if (it.save_training || it.save_evaluation) { :::#
     #::: if (it.save_training) { :::#
-    to_save_train = {"model": model, "optimizer": optimizer, "trainer": trainer}
+    to_save_train = {
+        "model": model,
+        "optimizer": optimizer,
+        "trainer": trainer,
+        "lr_scheduler": lr_scheduler,
+    }
     #::: } else { :::#
     to_save_train = None
     #::: } :::#
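
For context, the milestones_values added above encode a linear warmup from 0 to config.lr over the first epoch (len(dataloader_train) iterations), followed by a linear decay back to 0 at the last training iteration (config.max_epochs * len(dataloader_train)). Below is a minimal, self-contained sketch of the same pattern; the toy model and the iters_per_epoch, max_epochs, and peak_lr values are illustrative stand-ins, not part of the diff:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import PiecewiseLinear

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Illustrative values standing in for len(dataloader_train),
# config.max_epochs, and config.lr from the diff.
iters_per_epoch, max_epochs, peak_lr = 10, 3, 1e-3

# Same shape as the diff: warm up during epoch 1, decay to 0 by the end.
milestones_values = [
    (0, 0.0),
    (iters_per_epoch, peak_lr),
    (max_epochs * iters_per_epoch, 0.0),
]
lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)

# No-op update function; the engine only exists here to drive events.
trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

@trainer.on(Events.ITERATION_COMPLETED)
def log_lr(engine):
    # Print the learning rate after each scheduler step.
    print(f"iter {engine.state.iteration:2d}  lr={optimizer.param_groups[0]['lr']:.2e}")

trainer.run(range(iters_per_epoch), max_epochs=max_epochs)

Adding "lr_scheduler" to to_save_train also puts the scheduler's state under the save-training handler, presumably so that the schedule resumes at the correct point when training restarts from a checkpoint.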