Skip to content

Commit 15fba79

Browse files
Commit message: Fixed logger creation
Parent: c70578c · Commit: 15fba79

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,6 @@ python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=
5656
#### 8 TPUs on Colab
5757

5858
```bash
59-
python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu
60-
``
59+
python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8
60+
# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8
61+
```

main_fixmatch.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ def unpack_from_tensor(t):
3434
return sorted_op_names[k_index], bins, error
3535

3636

37-
def training(local_rank, cfg, logger):
37+
def training(local_rank, cfg):
38+
39+
logger = setup_logger(
40+
"FixMatch Training",
41+
distributed_rank=idist.get_rank()
42+
)
3843

3944
if local_rank == 0:
4045
logger.info(cfg.pretty())
@@ -176,11 +181,7 @@ def update_cta_rates():
176181
def main(cfg: DictConfig) -> None:
177182

178183
with idist.Parallel(backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node) as parallel:
179-
logger = setup_logger(
180-
"FixMatch Training",
181-
distributed_rank=idist.get_rank()
182-
)
183-
parallel.run(training, cfg, logger)
184+
parallel.run(training, cfg)
184185

185186

186187
if __name__ == "__main__":

main_fully_supervised.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
import trainers
99

1010

11-
def training(local_rank, cfg, logger):
11+
def training(local_rank, cfg):
12+
13+
logger = setup_logger(
14+
"Fully-Supervised Training",
15+
distributed_rank=idist.get_rank()
16+
)
1217

1318
if local_rank == 0:
1419
logger.info(cfg.pretty())
@@ -68,11 +73,7 @@ def train_step(engine, batch):
6873
def main(cfg: DictConfig) -> None:
6974

7075
with idist.Parallel(backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node) as parallel:
71-
logger = setup_logger(
72-
"Fully-Supervised Training",
73-
distributed_rank=idist.get_rank()
74-
)
75-
parallel.run(training, cfg, logger)
76+
parallel.run(training, cfg)
7677

7778

7879
if __name__ == "__main__":

0 commit comments

Comments (0)