
Commit 6b0e1f2

[WIP] fully supervised training app

1 parent 8179811

17 files changed: +845 -490 lines

README.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -7,6 +7,7 @@ Based on ["FixMatch: Simplifying Semi-Supervised Learning with Consistency and Co
 ## Requirements
 
 ```bash
+pip install --upgrade --pre hydra-core
 pip install --upgrade --pre pytorch-ignite
 ```
 
````

config/dataflow/cifar10.yaml

Lines changed: 7 additions & 0 deletions

```yaml
# @package _group_
name: cifar10

data_path: "/tmp/cifar10"

batch_size: 64
num_workers: 12
```
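For reference, a minimal sketch of how these fields would feed a supervised DataLoader. The `get_supervised_trainset` signature is taken from its call sites in dataflow/__init__.py further down; treat the exact wiring as an assumption, not the commit's actual code.

```python
# Assumed wiring, based on the call sites in dataflow/__init__.py below.
from torch.utils.data import DataLoader

from dataflow.cifar10 import get_supervised_trainset

dataset = get_supervised_trainset("/tmp/cifar10", num_train_samples_per_class=25)
loader = DataLoader(dataset, batch_size=64, num_workers=12, shuffle=True)
```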

config/fixmatch.yaml

Lines changed: 15 additions & 0 deletions

```yaml
hydra:
  run:
    dir: /tmp/output-fixmatch-cifar10-hydra/fully_supervised/${now:%Y%m%d-%H%M%S}

seed: 543
model: "resnet18"

defaults:
  - dataset: cifar10
  - solver: default
```

config/fully_supervised.yaml

Lines changed: 36 additions & 0 deletions

```yaml
hydra:
  run:
    dir: /tmp/output-fixmatch-cifar10-hydra/fully_supervised/${now:%Y%m%d-%H%M%S}
  job_logging:
    handlers:
      console:
        level: WARN
    root:
      level: WARN

name: fully-supervised

seed: 543
debug: false

# Model name (from torchvision) used to set up the model to train. For Wide ResNet, use "WRN-28-2".
model: "resnet18"
num_classes: 10

ema_decay: 0.999

defaults:
  - dataflow: cifar10
  - solver: default
  - ssl: full_sup

distributed:
  # Backend for the distributed configuration. Possible values: null, "nccl", "xla-tpu", "gloo", etc. Default: null.
  backend: null
  # Optional number of processes per node; useful when the main Python process spawns training as child processes.
  nproc_per_node: null

online_exp_tracking:
  wandb: false
```
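This config is composed by Hydra from the `defaults` groups (dataflow, solver, ssl). A minimal sketch of the consuming entrypoint follows; the script name and body are assumptions (the actual training script is among the 17 changed files but not shown in this excerpt), only the config layout comes from this commit.

```python
# Hypothetical entrypoint sketch; script name and body are assumptions.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="config", config_name="fully_supervised")
def main(cfg: DictConfig) -> None:
    # Composed config: fully_supervised.yaml + dataflow/cifar10.yaml
    # + solver/default.yaml + ssl/full_sup.yaml, plus CLI overrides, e.g.
    #   python main_fully_supervised.py model=WRN-28-2 dataflow.batch_size=128
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```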

config/solver/default.yaml

Lines changed: 30 additions & 0 deletions

```yaml
# @package _group_

num_epochs: 1024

epoch_length: 128  # epoch_length * num_epochs == 2 ** 17 iterations

checkpoint_every: 500

validate_every: 1

resume_from: null

optimizer:
  cls: torch.optim.SGD
  params:
    lr: 0.01
    momentum: 0.9
    weight_decay: 0.0001
    nesterov: false

supervised_criterion:
  cls: torch.nn.CrossEntropyLoss

lr_scheduler:
  cls: torch.optim.lr_scheduler.CosineAnnealingLR
  params:
    eta_min: 0.0
    T_max: null
```
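The optimizer, loss, and LR scheduler are specified as dotted `cls` paths plus `params` kwargs; `T_max` is left null, presumably so it can be filled at runtime from `num_epochs * epoch_length`. A hedged sketch of how such entries could be resolved (an assumed helper, not code from this commit):

```python
# Assumed helper, not code from this commit: resolve "cls" dotted paths from
# the solver config into live objects, with "params" as constructor kwargs.
import importlib


def instantiate(spec, *args, **overrides):
    module_path, _, cls_name = spec["cls"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), cls_name)
    kwargs = {**(spec.get("params") or {}), **overrides}
    return cls(*args, **kwargs)


# criterion = instantiate(cfg.solver.supervised_criterion)
# optimizer = instantiate(cfg.solver.optimizer, model.parameters())
# scheduler = instantiate(cfg.solver.lr_scheduler, optimizer,
#                         T_max=cfg.solver.num_epochs * cfg.solver.epoch_length)
```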

config/ssl/full_sup.yaml

Lines changed: 3 additions & 0 deletions

```yaml
# @package _group_

num_train_samples_per_class: 25
```
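With CIFAR-10's ten classes, 25 samples per class gives 25 × 10 = 250 labeled examples in total, matching one of the label budgets evaluated in the FixMatch paper.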

configs.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -28,13 +28,15 @@ def get_default_config():
         "model": "WRN-28-2",
         "momentum": 0.9,
         "weight_decay": 0.0005,
+
         "batch_size": batch_size,
         "num_workers": 12,
+
         "num_epochs": 1024,
         "epoch_length": 2 ** 16 // batch_size,  # epoch_length * num_epochs == 2 ** 20
         "learning_rate": 0.03,
-        "validate_every": 1,
 
+        "validate_every": 1,
         # Logging:
         "display_iters": True,
         "checkpoint_every": 200,
```

ctaugment/__init__.py

Lines changed: 53 additions & 0 deletions

```python
import json
from collections import OrderedDict

from ctaugment.ctaugment import *


class StorableCTAugment(CTAugment):
    def load_state_dict(self, state):
        for k in ["decay", "depth", "th", "rates"]:
            assert k in state, "{} not in {}".format(k, state.keys())
            setattr(self, k, state[k])

    def state_dict(self):
        return OrderedDict([(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]])


def get_default_cta():
    return StorableCTAugment()


def cta_apply(pil_img, ops):
    if ops is None:
        return pil_img
    for op, args in ops:
        pil_img = OPS[op].f(pil_img, *args)
    return pil_img


def deserialize(policy_str):
    return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)]


def stats(cta):
    return "\n".join(
        "%-16s %s"
        % (k, " / ".join(" ".join("%.2f" % x for x in cta.rate_to_p(rate)) for rate in cta.rates[k]))
        for k in sorted(OPS.keys())
    )


def interleave(x, batch, inverse=False):
    """
    TF code
    def interleave(x, batch):
        s = x.get_shape().as_list()
        return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:])
    """
    shape = x.shape
    axes = [batch, -1] if inverse else [-1, batch]
    return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:])


def deinterleave(x, batch):
    return interleave(x, batch, inverse=True)
```
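`interleave`/`deinterleave` regroup a stacked batch so labeled and unlabeled samples can pass through the model together and be split apart afterwards. A quick round-trip check, assuming the repo's `ctaugment` package is importable and torch is installed:

```python
# Illustrative round-trip check, not part of the commit.
import torch

from ctaugment import deinterleave, interleave

x = torch.arange(12).reshape(12, 1)  # stand-in for a batch of 12 samples
y = interleave(x, batch=4)           # regroup as in the TF reference code
assert torch.equal(deinterleave(y, batch=4), x)  # inverse recovers the original order
```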
File renamed without changes.

dataflow/__init__.py

Lines changed: 97 additions & 0 deletions

```python
from functools import partial

from torch.utils.data import Dataset

from ignite.utils import convert_tensor


class TransformedDataset(Dataset):
    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, i):
        dp = self.dataset[i]
        return self.transforms(dp)

    def __len__(self):
        return len(self.dataset)


def sup_prepare_batch(batch, device, non_blocking):
    x = convert_tensor(batch["image"], device, non_blocking)
    y = convert_tensor(batch["target"], device, non_blocking)
    return x, y


def cycle(dataloader):
    # Infinite iterator over a finite dataloader.
    while True:
        for b in dataloader:
            yield b


def get_supervised_train_loader(dataset_name, root, num_train_samples_per_class, download=True, **dataloader_kwargs):
    if dataset_name == "cifar10":
        from dataflow.cifar10 import get_supervised_trainset, get_supervised_train_loader, weak_transforms

        train_dataset = get_supervised_trainset(
            root, num_train_samples_per_class=num_train_samples_per_class, download=download
        )

        return get_supervised_train_loader(train_dataset, **dataloader_kwargs)

    else:
        raise ValueError("Unhandled dataset: {}".format(dataset_name))


def get_test_loader(dataset_name, root, download=True, **dataloader_kwargs):
    if dataset_name == "cifar10":
        from dataflow.cifar10 import get_test_loader

        return get_test_loader(root=root, download=download, **dataloader_kwargs)

    else:
        raise ValueError("Unhandled dataset: {}".format(dataset_name))


def get_unsupervised_train_loader(dataset_name, root, cta, download=True, **dataloader_kwargs):
    if dataset_name == "cifar10":
        from dataflow import cifar10

        full_train_dataset = cifar10.get_supervised_trainset(
            root, num_train_samples_per_class=None, download=download
        )

        strong_transforms = partial(cifar10.cta_image_transforms, cta=cta)

        # Call the cifar10-specific loader explicitly; a bare call here would
        # recurse into this dispatching function instead.
        return cifar10.get_unsupervised_train_loader(
            full_train_dataset,
            transforms_weak=cifar10.weak_transforms,
            transforms_strong=strong_transforms,
            **dataloader_kwargs
        )

    else:
        raise ValueError("Unhandled dataset: {}".format(dataset_name))


def get_cta_probe_loader(dataset_name, root, num_train_samples_per_class, cta, download=True, **dataloader_kwargs):
    if dataset_name == "cifar10":
        from dataflow.cifar10 import get_supervised_trainset, get_cta_probe_loader

        train_dataset = get_supervised_trainset(
            root, num_train_samples_per_class=num_train_samples_per_class, download=download
        )

        return get_cta_probe_loader(train_dataset, cta=cta, **dataloader_kwargs)

    else:
        raise ValueError("Unhandled dataset: {}".format(dataset_name))
```
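A hedged sketch of how these helpers plug together in the fully supervised loop. The dict-style batches (`{"image": ..., "target": ...}`) follow `sup_prepare_batch`; the exact call sites, and the assumption that `batch_size`/`num_workers` pass through `**dataloader_kwargs` to the cifar10-specific loaders, are mine, not the commit's.

```python
# Assumed wiring, not code from this commit.
import torch

from dataflow import cycle, get_supervised_train_loader, get_test_loader, sup_prepare_batch

train_loader = get_supervised_train_loader(
    "cifar10", root="/tmp/cifar10", num_train_samples_per_class=25,
    batch_size=64, num_workers=12,
)
test_loader = get_test_loader("cifar10", root="/tmp/cifar10", batch_size=64, num_workers=12)

device = "cuda" if torch.cuda.is_available() else "cpu"
train_iter = cycle(train_loader)  # infinite stream; epochs are bounded by solver.epoch_length

x, y = sup_prepare_batch(next(train_iter), device=device, non_blocking=True)
```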
