
Commit c70578c

Updated configs and README
1 parent: febe9c8 · commit: c70578c

3 files changed: +66 −56 lines

README.md

Lines changed: 27 additions & 55 deletions
````diff
@@ -7,82 +7,54 @@ Based on ["FixMatch: Simplifying Semi-Supervised Learning with Consistency and Co
 ## Requirements
 
 ```bash
-pip install --upgrade --pre hydra-core
+pip install --upgrade --pre hydra-core tensorboardX
 pip install --upgrade --pre pytorch-ignite
 ```
 
 ## Training
 
 ```bash
-python -u main_fixmatch.py
-# or python -u main_fixmatch.py --params "data_path=/path/to/cifar10"
+python -u main_fixmatch.py model=WRN-28-2
 ```
 
 This script automatically trains on multiple GPUs (`torch.nn.DistributedParallel`).
 
-### Distributed Data Parallel (DDP) on multiple GPUs (Experimental)
+If you need to specify the input/output folders:
+```
+python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
+```
 
-For example, training on 2 GPUs
+To use the wandb logger, log in first and run with `online_exp_tracking.wandb=true`:
 ```bash
-python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py --params="distributed=True"
+wandb login <token>
+python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true
 ```
 
-### TPU(s) on Colab (Experimental)
-
-#### Installation
+To see other options:
 ```bash
-VERSION = "1.5"
-!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
-!python pytorch-xla-env-setup.py --version $VERSION
+python -u main_fixmatch.py --help
 ```
 
-#### Single TPU
+### Training curves visualization
+
+By default, we use TensorBoard to log training curves:
+
 ```bash
-python -u main_fixmatch.py --params="device='xla'"
+tensorboard --logdir=/tmp/output-fixmatch-cifar10-hydra/
 ```
 
-#### 8 TPUs on Colab
 
+### Distributed Data Parallel (DDP) on multiple GPUs (Experimental)
+
+For example, training on 2 GPUs:
 ```bash
-python -u main_fixmatch.py --params="device='xla';distributed=True"
+python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
 ```
 
-## TODO
-
-* [x] Resume training from existing checkpoint:
-  * [x] save/load CTA
-  * [x] save ema model
-
-* [ ] DDP:
-  * [x] Synchronize CTA across processes
-  * [x] Unified GPU and TPU approach
-  * [ ] Bug: DDP performances are worse than DP on the first epochs
-
-* [ ] Logging to an online platform: NeptuneML or Trains or W&B
-
-* [ ] Replace PIL augmentations with Albumentations
-
-```python
-class BlurLimitSampler:
-    def __init__(self, blur, weights):
-        self.blur = blur  # [3, 5, 7]
-        self.weights = weights  # [0.1, 0.5, 0.4]
-    def get_params(self):
-        return {"ksize": int(random.choice(self.blur, p=self.weights))}
-
-class Blur(ImageOnlyTransform):
-    def __init__(self, blur_limit, always_apply=False, p=0.5):
-        super(Blur, self).__init__(always_apply, p)
-        self.blur_limit = blur_limit
-
-    def apply(self, image, ksize=3, **params):
-        return F.blur(image, ksize)
-
-    def get_params(self):
-        if isinstance(self.blur_limit, BlurLimitSampler):
-            return self.blur_limit.get_params()
-        return {"ksize": int(random.choice(np.arange(self.blur_limit[0], self.blur_limit[1] + 1, 2)))}
-
-    def get_transform_init_args_names(self):
-        return ("blur_limit",)
-```
+### TPU(s) on Colab (Experimental)
+
+#### 8 TPUs on Colab
+
+```bash
+python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu
+```
````
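The new command line replaces the old `--params` string with Hydra overrides (`model=WRN-28-2`, `dataflow.data_path=...`, `distributed.backend=...`). As a rough sketch, such overrides assume a Hydra 1.0-style entry point along these lines; the `config_path`/`config_name` values and config group names are illustrative assumptions, not taken from the repo:

```python
# Hypothetical sketch of a Hydra entry point consistent with the CLI shown above.
# Config group names ("model", "dataflow", "distributed") are assumptions.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    # Any field of the composed config can be overridden from the command line,
    # e.g. `python -u main_fixmatch.py model=WRN-28-2 dataflow.data_path=/data/cifar10/`.
    print(OmegaConf.to_yaml(cfg))
    # train(cfg)  # the training loop would consume cfg.model, cfg.dataflow, cfg.solver, ...


if __name__ == "__main__":
    main()
```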

TODO

Lines changed: 38 additions & 0 deletions
## TODO

* [x] Resume training from existing checkpoint:
  * [x] save/load CTA
  * [x] save ema model

* [ ] DDP:
  * [x] Synchronize CTA across processes
  * [ ] Bug: DDP performances are worse than DP on the first epochs

* [x] Logging to an online platform: W&B

* [ ] Replace PIL augmentations with Albumentations

```python
# Sketch of an Albumentations-style blur transform with a weighted kernel-size sampler.
# Imports added for completeness; the functional blur helper's location may differ
# across albumentations versions (assumption).
import numpy as np
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.augmentations import functional as F


class BlurLimitSampler:
    def __init__(self, blur, weights):
        self.blur = blur        # e.g. [3, 5, 7]
        self.weights = weights  # e.g. [0.1, 0.5, 0.4]

    def get_params(self):
        # np.random.choice (not random.choice) is needed to honor the weights `p`.
        return {"ksize": int(np.random.choice(self.blur, p=self.weights))}


class Blur(ImageOnlyTransform):
    def __init__(self, blur_limit, always_apply=False, p=0.5):
        super(Blur, self).__init__(always_apply, p)
        self.blur_limit = blur_limit

    def apply(self, image, ksize=3, **params):
        return F.blur(image, ksize)

    def get_params(self):
        if isinstance(self.blur_limit, BlurLimitSampler):
            return self.blur_limit.get_params()
        # Otherwise treat blur_limit as an inclusive (min, max) range of odd kernel sizes.
        return {"ksize": int(np.random.choice(np.arange(self.blur_limit[0], self.blur_limit[1] + 1, 2)))}

    def get_transform_init_args_names(self):
        return ("blur_limit",)
```
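For context, a minimal usage sketch of how the proposed transform could slot into an albumentations pipeline; the pipeline composition and dummy image are illustrative and not part of the commit:

```python
# Illustrative usage of the Blur / BlurLimitSampler classes defined above.
import numpy as np
import albumentations as A

# Sample a kernel size of 3, 5 or 7 with probabilities 0.1 / 0.5 / 0.4.
sampler = BlurLimitSampler(blur=[3, 5, 7], weights=[0.1, 0.5, 0.4])

transform = A.Compose([
    Blur(blur_limit=sampler, p=0.5),  # the custom Blur above, not albumentations' built-in A.Blur
    A.HorizontalFlip(p=0.5),
])

image = np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8)  # dummy CIFAR-sized image
augmented = transform(image=image)["image"]
```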

config/solver/default.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ resume_from: null
 optimizer:
   cls: torch.optim.SGD
   params:
-    lr: 0.01
+    lr: 0.03
     momentum: 0.9
     weight_decay: 0.0001
     nesterov: false
```
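The `optimizer` block follows a `cls`-plus-`params` convention. A rough sketch of how such a block could be resolved into an optimizer instance; the helper below is illustrative, not the repo's actual code:

```python
# Hypothetical helper: build an optimizer from a {"cls": "...", "params": {...}} config block.
import importlib

import torch.nn as nn


def build_optimizer(cfg: dict, parameters):
    # Resolve the dotted path "torch.optim.SGD" into the SGD class, then
    # instantiate it with the model parameters and the configured keyword arguments.
    module_name, class_name = cfg["cls"].rsplit(".", 1)
    optim_cls = getattr(importlib.import_module(module_name), class_name)
    return optim_cls(parameters, **cfg["params"])


model = nn.Linear(10, 10)
optimizer = build_optimizer(
    {"cls": "torch.optim.SGD",
     "params": {"lr": 0.03, "momentum": 0.9, "weight_decay": 0.0001, "nesterov": False}},
    model.parameters(),
)
print(optimizer)
```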
