Skip to content

Commit 2e2bab6

Browse files
authored
Merge branch 'main' into update-log-config
2 parents 7aad855 + e5cd96b commit 2e2bab6

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

src/templates/template-vision-classification/data.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,20 @@ def setup_data(config: Any):
1515
#::: if (it.use_dist) { :::#
1616
local_rank = idist.get_local_rank()
1717
#::: } :::#
18-
transform = T.Compose(
18+
train_transform = T.Compose(
1919
[
20+
T.Pad(4),
21+
T.RandomCrop(32, fill=128),
22+
T.RandomHorizontalFlip(),
2023
T.ToTensor(),
21-
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
24+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
25+
]
26+
)
27+
28+
eval_transform = T.Compose(
29+
[
30+
T.ToTensor(),
31+
T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
2232
]
2333
)
2434

@@ -32,13 +42,13 @@ def setup_data(config: Any):
3242
root=config.data_path,
3343
train=True,
3444
download=True,
35-
transform=transform,
45+
transform=train_transform,
3646
)
3747
dataset_eval = torchvision.datasets.CIFAR10(
3848
root=config.data_path,
3949
train=False,
4050
download=True,
41-
transform=transform,
51+
transform=eval_transform,
4252
)
4353

4454
#::: if (it.use_dist) { :::#

src/templates/template-vision-classification/main.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ignite.distributed as idist
66
from data import setup_data
77
from ignite.engine import Events
8+
from ignite.handlers import PiecewiseLinear
89
from ignite.metrics import Accuracy, Loss
910
from ignite.utils import manual_seed
1011
from models import setup_model
@@ -31,6 +32,15 @@ def run(local_rank: int, config: Any):
3132
model = idist.auto_model(setup_model(config.model))
3233
optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
3334
loss_fn = nn.CrossEntropyLoss().to(device=device)
35+
milestones_values = [
36+
(0, 0.0),
37+
(
38+
len(dataloader_train),
39+
config.lr,
40+
),
41+
(config.max_epochs * len(dataloader_train), 0.0),
42+
]
43+
lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)
3444

3545
# trainer and evaluator
3646
trainer = setup_trainer(
@@ -54,10 +64,17 @@ def run(local_rank: int, config: Any):
5464
logger.info("Configuration: \n%s", pformat(vars(config)))
5565
trainer.logger = evaluator.logger = logger
5666

67+
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
68+
5769
# setup ignite handlers
5870
#::: if (it.save_training || it.save_evaluation) { :::#
5971
#::: if (it.save_training) { :::#
60-
to_save_train = {"model": model, "optimizer": optimizer, "trainer": trainer}
72+
to_save_train = {
73+
"model": model,
74+
"optimizer": optimizer,
75+
"trainer": trainer,
76+
"lr_scheduler": lr_scheduler,
77+
}
6178
#::: } else { :::#
6279
to_save_train = None
6380
#::: } :::#

0 commit comments

Comments
 (0)