import torch
- import torch.distributed as dist

+ import ignite.distributed as idist
from ignite.engine import Events
+ from ignite.utils import manual_seed, setup_logger
+
+ import hydra
+ from hydra.utils import instantiate
+ from omegaconf import DictConfig

import utils
- from base_train import main, BaseTrainer
- from configs import get_default_config
- from ctaugment import OPS
+ import trainers
+ from ctaugment import get_default_cta, OPS, interleave, deinterleave


sorted_op_names = sorted(list(OPS.keys()))
@@ -30,22 +34,43 @@ def unpack_from_tensor(t):
    return sorted_op_names[k_index], bins, error


- class FixMatchTrainer(BaseTrainer):
+ def training(local_rank, cfg, logger):

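+     # per-process setup: seed, device, model/EMA/optimizer, loss criteria, CTAugment policy and data loaders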
-     output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"]
+     if local_rank == 0:
+         logger.info(cfg.pretty())
+
+     rank = idist.get_rank()
+     manual_seed(cfg.seed + rank)
+     device = idist.device()
+
+     model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg)
+
+     unsup_criterion = instantiate(cfg.solver.unsupervised_criterion)
+
+     cta = get_default_cta()
+
+     supervised_train_loader, test_loader, unsup_train_loader, cta_probe_loader = \
+         utils.get_dataflow(cfg, cta=cta, with_unsup=True)

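+     # FixMatch train step: supervised CE on labeled images plus confidence-masked CE on pseudo-labels
+     # predicted from weak augmentations and applied to the strong augmentations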
-     def train_step(self, engine, batch):
-         self.model.train()
-         self.optimizer.zero_grad()
+     def train_step(engine, batch):
+         model.train()
+         optimizer.zero_grad()

-         x, y = batch["sup_batch"]
-         weak_x, strong_x = batch["unsup_batch"]
+         x, y = batch["sup_batch"]["image"], batch["sup_batch"]["target"]
+         if x.device != device:
+             x = x.to(device, non_blocking=True)
+             y = y.to(device, non_blocking=True)
+
+         weak_x, strong_x = batch["unsup_batch"]["image"], batch["unsup_batch"]["strong_aug"]
+         if weak_x.device != device:
+             weak_x = weak_x.to(device, non_blocking=True)
+             strong_x = strong_x.to(device, non_blocking=True)

        # according to TF code: single forward pass on concat data: [x, weak_x, strong_x]
-         le = 2 * self.config["mu_ratio"] + 1
-         x_cat = utils.interleave(torch.cat([x, weak_x, strong_x], dim=0), le)
-         y_pred_cat = self.model(x_cat)
-         y_pred_cat = utils.deinterleave(y_pred_cat, le)
+         le = 2 * engine.state.mu_ratio + 1
+         x_cat = interleave(torch.cat([x, weak_x, strong_x], dim=0), le)
+         y_pred_cat = model(x_cat)
+         y_pred_cat = deinterleave(y_pred_cat, le)

        idx1 = len(x)
        idx2 = idx1 + len(weak_x)
@@ -54,25 +79,20 @@ def train_step(self, engine, batch):
        y_strong_preds = y_pred_cat[idx2:, ...]  # logits_strong

        # supervised learning:
-         sup_loss = self.sup_criterion(y_pred, y)
+         sup_loss = sup_criterion(y_pred, y)

        # unsupervised learning:
        y_weak_probas = torch.softmax(y_weak_preds, dim=1).detach()
        y_pseudo = y_weak_probas.argmax(dim=1)
        max_y_weak_probas, _ = y_weak_probas.max(dim=1)
-         unsup_loss_mask = (max_y_weak_probas >= self.confidence_threshold).float()
-         unsup_loss = (self.unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean()
+         unsup_loss_mask = (max_y_weak_probas >= engine.state.confidence_threshold).float()
+         unsup_loss = (unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean()

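+         # FixMatch objective: total_loss = sup_loss + lambda_u * unsup_loss, where the unsupervised
+         # term only counts samples whose max weak-augmentation probability reaches the confidence threshold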
-         total_loss = sup_loss + self.lambda_u * unsup_loss
+         total_loss = sup_loss + engine.state.lambda_u * unsup_loss

-         if self.config["with_nv_amp_level"] is not None:
-             from apex import amp
-             with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
-                 scaled_loss.backward()
-         else:
-             total_loss.backward()
+         total_loss.backward()

-         self.optimizer.step()
+         optimizer.step()

        return {
            "total_loss": total_loss.item(),
@@ -81,57 +101,87 @@ def train_step(self, engine, batch):
            "mask": unsup_loss_mask.mean().item()  # this should not be averaged for DDP
        }

-     def setup(self, **kwargs):
-         super(FixMatchTrainer, self).setup(**kwargs)
-         self.confidence_threshold = self.config["confidence_threshold"]
-         self.lambda_u = self.config["lambda_u"]
-         self.add_event_handler(Events.ITERATION_COMPLETED, self.update_cta_rates)
-         self.distributed = dist.is_available() and dist.is_initialized()
+     output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"]

-     def update_cta_rates(self):
-         x, y, policies = self.state.batch["cta_probe_batch"]
-         self.ema_model.eval()
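+     # build the Ignite trainer engine around train_step via the project's trainers.create_trainer helper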
+     trainer = trainers.create_trainer(
+         train_step,
+         output_names=output_names,
+         model=model,
+         ema_model=ema_model,
+         optimizer=optimizer,
+         lr_scheduler=lr_scheduler,
+         supervised_train_loader=supervised_train_loader,
+         test_loader=test_loader,
+         cfg=cfg,
+         logger=logger,
+         cta=cta,
+         unsup_train_loader=unsup_train_loader,
+         cta_probe_loader=cta_probe_loader
+     )
+
+     trainer.state.confidence_threshold = cfg.ssl.confidence_threshold
+     trainer.state.lambda_u = cfg.ssl.lambda_u
+     trainer.state.mu_ratio = cfg.ssl.mu_ratio
+
+     distributed = idist.get_world_size() > 1
+
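+     # after every iteration, refresh CTAugment bin rates from the EMA model's error on the probe batch;
+     # with multiple processes, per-op errors are all-gathered so every worker applies the same updates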
+     @trainer.on(Events.ITERATION_COMPLETED)
+     def update_cta_rates():
+         batch = trainer.state.batch
+         x, y = batch["cta_probe_batch"]["image"], batch["cta_probe_batch"]["target"]
+         if x.device != device:
+             x = x.to(device, non_blocking=True)
+             y = y.to(device, non_blocking=True)
+
+         policies = batch["cta_probe_batch"]["policy"]
+
+         ema_model.eval()
        with torch.no_grad():
-             y_pred = self.ema_model(x)
+             y_pred = ema_model(x)
            y_probas = torch.softmax(y_pred, dim=1)  # (N, C)

-         if not self.distributed:
-             for y_proba, t, policy in zip(y_probas, y, policies):
+         if not distributed:
+             for y_proba, t, policy in zip(y_probas, y, policies):
                error = y_proba
                error[t] -= 1
                error = torch.abs(error).sum()
-                 self.cta.update_rates(policy, 1.0 - 0.5 * error.item())
+                 cta.update_rates(policy, 1.0 - 0.5 * error.item())
        else:
            error_per_op = []
            for y_proba, t, policy in zip(y_probas, y, policies):
                error = y_proba
                error[t] -= 1
                error = torch.abs(error).sum()
-                 for k, bins in policy:
+                 for k, bins in policy:
                    error_per_op.append(pack_as_tensor(k, bins, error))
            error_per_op = torch.stack(error_per_op)
-             # all gather
-             tensor_list = [
-                 torch.empty_like(error_per_op)
-                 for _ in range(dist.get_world_size())
-             ]
-             dist.all_gather(tensor_list, error_per_op)
-             tensor_list = torch.cat(tensor_list, dim=0)
+             # all gather
+             tensor_list = idist.all_gather(error_per_op)
            # update cta rates
            for t in tensor_list:
-                 k, bins, error = unpack_from_tensor(t)
-                 self.cta.update_rates([(k, bins), ], 1.0 - 0.5 * error)
+                 k, bins, error = unpack_from_tensor(t)
+                 cta.update_rates([(k, bins), ], 1.0 - 0.5 * error)
+
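+     # run the training loop; in debug mode the run is capped at 2 epochs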
+     epoch_length = cfg.solver.epoch_length
+     num_epochs = cfg.solver.num_epochs if not cfg.debug else 2
+     try:
+         trainer.run(supervised_train_loader, epoch_length=epoch_length, max_epochs=num_epochs)
+     except Exception as e:
+         import traceback
+
+         print(traceback.format_exc())
+

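+ # hydra entry point: idist.Parallel spawns nproc_per_node processes with the configured backend and runs training() in each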
+ @hydra.main(config_path="config", config_name="fixmatch")
+ def main(cfg: DictConfig) -> None:

- def get_fixmatch_config():
-     config = get_default_config()
-     config.update({
-         # FixMatch settings
-         "confidence_threshold": 0.95,
-         "lambda_u": 1.0,
-     })
-     return config
+     with idist.Parallel(backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node) as parallel:
+         logger = setup_logger(
+             "FixMatch Training",
+             distributed_rank=idist.get_rank()
+         )
+         parallel.run(training, cfg, logger)


if __name__ == "__main__":
-     main(FixMatchTrainer(), get_fixmatch_config())
+     main()