3 files changed: 17 additions, 14 deletions
@@ -56,5 +56,6 @@ python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=
 #### 8 TPUs on Colab
 
 ``` bash
-python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu
-` `
+python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8
+# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8
+```

@@ -34,7 +34,12 @@ def unpack_from_tensor(t):
     return sorted_op_names[k_index], bins, error
 
 
-def training(local_rank, cfg, logger):
+def training(local_rank, cfg):
+
+    logger = setup_logger(
+        "FixMatch Training",
+        distributed_rank=idist.get_rank()
+    )
 
     if local_rank == 0:
         logger.info(cfg.pretty())
@@ -176,11 +181,7 @@ def update_cta_rates():
 def main(cfg: DictConfig) -> None:
 
     with idist.Parallel(backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node) as parallel:
-        logger = setup_logger(
-            "FixMatch Training",
-            distributed_rank=idist.get_rank()
-        )
-        parallel.run(training, cfg, logger)
+        parallel.run(training, cfg)
 
 
 if __name__ == "__main__":

@@ -8,7 +8,12 @@
 import trainers
 
 
-def training(local_rank, cfg, logger):
+def training(local_rank, cfg):
+
+    logger = setup_logger(
+        "Fully-Supervised Training",
+        distributed_rank=idist.get_rank()
+    )
 
     if local_rank == 0:
         logger.info(cfg.pretty())
@@ -68,11 +73,7 @@ def train_step(engine, batch):
 def main(cfg: DictConfig) -> None:
 
     with idist.Parallel(backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node) as parallel:
-        logger = setup_logger(
-            "Fully-Supervised Training",
-            distributed_rank=idist.get_rank()
-        )
-        parallel.run(training, cfg, logger)
+        parallel.run(training, cfg)
 
 
 if __name__ == "__main__":
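
The net effect in both scripts is the same: the logger is created inside the spawned `training` function rather than built in `main` and handed through `parallel.run`, so every worker process sets up its own logger. Below is a minimal, self-contained sketch of that pattern, assuming PyTorch-Ignite's `idist.Parallel`, `idist.get_rank()`, and `ignite.utils.setup_logger`; the plain-dict `cfg`, the logger name, and the `backend=None` launch are placeholders for illustration, not the repository's actual configuration.

```python
import ignite.distributed as idist
from ignite.utils import setup_logger


def training(local_rank, cfg):
    # Each worker process builds its own logger; `distributed_rank` suppresses
    # duplicate log output from non-zero ranks.
    logger = setup_logger("Example Training", distributed_rank=idist.get_rank())

    if local_rank == 0:
        logger.info(f"Running with config: {cfg}")


if __name__ == "__main__":
    cfg = {"model": "resnet18"}  # stand-in for the Hydra DictConfig used in the scripts
    # backend=None runs a single local process; e.g. backend="xla-tpu" with
    # nproc_per_node=8 would spawn 8 workers, each calling training(local_rank, cfg).
    with idist.Parallel(backend=None) as parallel:
        parallel.run(training, cfg)
```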