Based on ["FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence"](https://arxiv.org/abs/2001.07685)

## Requirements

```bash
pip install --upgrade --pre hydra-core tensorboardX
pip install --upgrade --pre pytorch-ignite
```

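If you want to double-check that the pre-release packages were picked up, a quick import check such as the following can be used (a minimal sketch; the exact version strings depend on when you install):

```python
# Optional sanity check that the pre-release packages are importable (illustrative).
import hydra
import ignite
import tensorboardX

print("hydra-core:", hydra.__version__)
print("pytorch-ignite:", ignite.__version__)
print("tensorboardX:", tensorboardX.__version__)
```
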
## Training

```bash
python -u main_fixmatch.py model=WRN-28-2
```

This script automatically trains on multiple GPUs (`torch.nn.DataParallel`).

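As a rough illustration of what that multi-GPU wrapping amounts to (a generic PyTorch sketch, not the actual code in `main_fixmatch.py`):

```python
import torch
import torch.nn as nn

# Generic illustration: if several GPUs are visible, DataParallel replicates the
# model and splits each input batch across them.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
```
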
If you need to specify the input/output folders:

```bash
python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
```

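These arguments are Hydra overrides: `model=WRN-28-2` picks a config-group option, while `dataflow.data_path` and `hydra.run.dir` override single values. A minimal sketch of a Hydra entry point that consumes such overrides is shown below; the config layout (`config/config.yaml` with a `dataflow` node) is an assumption for illustration, not necessarily how this repository is organized:

```python
# minimal_hydra_app.py -- illustrative sketch only, not this repository's entry point.
# Assumes a config/config.yaml exists that defines (at least) dataflow.data_path.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    # Command-line overrides such as dataflow.data_path=/data/cifar10/ are merged
    # into cfg, and Hydra changes the working directory to hydra.run.dir first.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```
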
To use the wandb logger, log in first and run with `online_exp_tracking.wandb=true`:

```bash
wandb login <token>
python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true
```

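For reference, the underlying `wandb` API that such a logger builds on looks roughly like this (an illustrative sketch; the project and metric names are placeholders, not the ones used by this repository):

```python
import wandb

# Illustrative experiment tracking with wandb; project/metric names are placeholders.
wandb.init(project="fixmatch-cifar10", config={"model": "WRN-28-2", "batch_size": 64})
for step in range(10):
    wandb.log({"train/loss": 1.0 / (step + 1)}, step=step)
wandb.finish()
```
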
To see other options:

```bash
python -u main_fixmatch.py --help
```

### Training curves visualization

By default, we use TensorBoard to log training curves:

```bash
tensorboard --logdir=/tmp/output-fixmatch-cifar10-hydra/
```

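These event files are written with the `tensorboardX` package listed in the requirements (presumably via the training script's logging handlers); schematically, the logging side looks like the following (an illustrative sketch with made-up tags, not the repository's handler code):

```python
from tensorboardX import SummaryWriter

# Illustrative only: write a few scalars to a directory that `tensorboard --logdir=...` can read.
writer = SummaryWriter(logdir="/tmp/output-fixmatch-cifar10-hydra/example-run")
for step in range(10):
    writer.add_scalar("train/loss", 1.0 / (step + 1), global_step=step)
writer.close()
```
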
### Distributed Data Parallel (DDP) on multiple GPUs (Experimental)

For example, to train on 2 GPUs:

```bash
python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
```

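Each process spawned by `torch.distributed.launch` receives a `--local_rank` argument and joins a process group; a generic sketch of that worker-side pattern (not the repository's actual setup, which goes through pytorch-ignite) looks like this and only works when started through the launcher:

```python
# ddp_sketch.py -- generic torch.distributed.launch worker pattern, illustrative only.
import argparse

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl")  # MASTER_ADDR/PORT, RANK, WORLD_SIZE come from the launcher's env
# ... build the model, wrap it in torch.nn.parallel.DistributedDataParallel, train ...
dist.destroy_process_group()
```
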
### TPU(s) on Colab (Experimental)

#### 8 TPUs on Colab

```bash
python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu
```

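On Colab this additionally requires `torch_xla` to be installed in the runtime. The `distributed.backend=xla-tpu` option suggests the script relies on pytorch-ignite's distributed helpers; the documented pattern for that backend looks roughly like the sketch below (whether `main_fixmatch.py` is structured exactly this way is an assumption):

```python
# Illustrative sketch of ignite's distributed launcher with the xla-tpu backend.
import ignite.distributed as idist


def training(local_rank, config):
    device = idist.device()  # an XLA device when running on TPU
    print(f"process {local_rank} runs on {device} with model={config['model']}")


if __name__ == "__main__":
    config = {"model": "WRN-28-2"}
    with idist.Parallel(backend="xla-tpu", nproc_per_node=8) as parallel:
        parallel.run(training, config)
```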