import os
from contextlib import contextmanager
from typing import Dict

import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import DataLoader
from torch import nn
from torchvision import datasets, transforms
from torch.nn import functional as F

from catalyst import dl
from catalyst.utils import set_global_seed, prepare_cudnn
from catalyst.loggers.console import ConsoleLogger


BATCH_SIZE = 256

class DenseBlock(nn.Module):
    """Dense block: every layer receives the concatenation of all previous feature maps."""

    def __init__(self, in_channels, growth_rate, num_layers, kernel_size=3):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels + i * growth_rate,
                        growth_rate,
                        kernel_size,
                        # "same" padding so spatial dims are preserved for any odd kernel size
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    nn.BatchNorm2d(growth_rate),
                    nn.ReLU(inplace=True),
                )
            )

    def forward(self, in_tensor):
        out = in_tensor
        for layer in self.layers:
            # concatenate the new growth_rate channels onto the running feature map
            out = torch.cat((out, layer(out)), dim=1)
        return out

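# A minimal shape sketch (illustrative only, not executed by this script): the block
# keeps the spatial size and grows channels by num_layers * growth_rate, e.g.
#
#   block = DenseBlock(in_channels=64, growth_rate=32, num_layers=6)
#   out = block(torch.randn(2, 64, 32, 32))  # expected shape: (2, 64 + 6 * 32, 32, 32)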

class DenseNet(nn.Module):
    def __init__(self, num_classes, growth_rate=32, num_dense_blocks=6, num_layers_per_block=6):
        super(DenseNet, self).__init__()
        self.num_classes = num_classes

        # stem: 3x3 conv that lifts the 3-channel input to 64 feature maps
        self.in_preproc = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.features = nn.ModuleList()
        in_channels = 64
        for i in range(num_dense_blocks):
            self.features.append(
                DenseBlock(in_channels, growth_rate, num_layers_per_block)
            )
            in_channels += growth_rate * num_layers_per_block
            # transition layer between dense blocks: halve the channels and the spatial size
            if i < num_dense_blocks - 1:
                self.features.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, bias=False),
                        nn.BatchNorm2d(in_channels // 2),
                        nn.ReLU(inplace=True),
                        nn.AvgPool2d(kernel_size=2, stride=2),
                    )
                )
                in_channels = in_channels // 2
        self.features.append(nn.AdaptiveAvgPool2d(1))

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.25),
            nn.Linear(in_channels, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, self.num_classes),
        )

    def forward(self, in_tensor):
        x = self.in_preproc(in_tensor)
        for block in self.features:
            x = block(x)
        feat = torch.flatten(x, 1)
        logits = self.classifier(feat)
        return logits

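# Channel bookkeeping sketch for the defaults (growth_rate=32, 6 blocks of 6 layers),
# shown only for illustration: each dense block adds 6 * 32 = 192 channels and each
# transition halves them, so
#   64 -> 256 -> 128 -> 320 -> 160 -> 352 -> 176 -> 368 -> 184 -> 376 -> 188 -> 380,
# and 380 is the in_features of the first Linear layer after global average pooling.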

def _format_metrics(dct):
    return " | ".join(f"{k}: {float(dct[k]):.03}" for k in sorted(dct.keys()))


class CustomLogger(ConsoleLogger):
    """Custom console logger for parameters and metrics.

    Prints the metrics to the console during the experiment.

    Note:
        We inherit from ConsoleLogger to override the default Catalyst logging behaviour.
    """

    def log_metrics(
        self,
        metrics: Dict[str, float],
        scope: str = None,
        # experiment info
        run_key: str = None,
        global_epoch_step: int = 0,
        global_batch_step: int = 0,
        global_sample_step: int = 0,
        # stage info
        stage_key: str = None,
        stage_epoch_len: int = 0,
        stage_epoch_step: int = 0,
        stage_batch_step: int = 0,
        stage_sample_step: int = 0,
        # loader info
        loader_key: str = None,
        loader_batch_len: int = 0,
        loader_sample_len: int = 0,
        loader_batch_step: int = 0,
        loader_sample_step: int = 0,
    ) -> None:
        """Logs loader and epoch metrics to stdout."""
        if scope == "loader":
            prefix = f"{loader_key} ({stage_epoch_step}/{stage_epoch_len}) "
            print(prefix + _format_metrics(metrics))
        elif scope == "epoch":
            prefix = f"* Epoch ({stage_epoch_step}/{stage_epoch_len}) "
            print(prefix + _format_metrics(metrics["_epoch_"]))

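# With this logger, a loader-scope line would look roughly like the following
# (hypothetical metric names and values, shown only to illustrate the format built above):
#
#   valid (2/6) accuracy01: 0.712 | accuracy03: 0.905 | accuracy05: 0.971 | loss: 0.823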

def get_transforms():
    # CIFAR-10 per-channel mean and std of the pixel values
    means = np.array((0.4914, 0.4822, 0.4465))
    stds = np.array((0.2023, 0.1994, 0.2010))
    base_transforms = [transforms.ToTensor(), transforms.Normalize(means, stds)]
    # train-time augmentations operate on the PIL image, so they go before ToTensor/Normalize
    augmented_transforms = [
        transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(hue=0.01, brightness=0.3, contrast=0.3, saturation=0.3),
    ]
    augmented_transforms += base_transforms

    transform_basic = transforms.Compose(base_transforms)
    transform_augment = transforms.Compose(augmented_transforms)
    return transform_basic, transform_augment

@contextmanager
def infer(model):
    """Temporarily switches the model to inference mode (eval + no_grad) and restores
    the previous training/eval state on exit."""
    status = model.training
    model.train(False)
    with torch.no_grad():
        try:
            yield None
        finally:
            model.train(status)

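# `infer` is not used by the pipeline below; a hypothetical usage sketch:
#
#   with infer(model):
#       logits = model(batch)  # eval mode, no autograd graph is built
#   # after the block, the original model.training flag is restored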

def load_ckpt(path, model, device=torch.device("cpu")):
    """
    Load saved checkpoint weights into a model.
    :param path: full path to the checkpoint file
    :param model: initialized model instance (a subclass of nn.Module)
    :param device: torch device used as map_location
    :return: model with the loaded state_dict
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"no file: {path}")

    ckpt = torch.load(path, map_location=device)
    ckpt_dict = ckpt["model_state_dict"]
    model_dict = model.state_dict()
    # keep only the keys that the current model actually defines
    ckpt_dict = {k: v for k, v in ckpt_dict.items() if k in model_dict}
    model_dict.update(ckpt_dict)
    model.load_state_dict(model_dict)
    return model


@torch.no_grad()
def validate_model(model, loader, device):
    """
    Evaluate the model on a held-out set.
    :param model: initialized model instance (a subclass of nn.Module) with a loaded state dict
    :param loader: batch data loader for the evaluation set
    :param device: torch device used for validation
    :return: dict of performance metrics
    """
    label_list = []
    pred_list = []
    model.train(False)
    model = model.to(device)

    for data_tensor, lbl_tensor in loader:
        lbl_values = lbl_tensor.cpu().view(-1).tolist()
        label_list.extend(lbl_values)
        logits = model(data_tensor.to(device))
        # softmax is monotonic, so argmax over probabilities equals argmax over raw logits
        scores = F.softmax(logits.detach().cpu(), 1).numpy()
        pred_labels = np.argmax(scores, 1)
        pred_list.extend(pred_labels.ravel().tolist())

    labels = np.array(label_list)
    predicted = np.array(pred_list)
    acc = accuracy_score(labels, predicted)
    print(f"model accuracy: {acc:.4f}")
    metric_dict = {"accuracy": acc}
    return metric_dict

def test(num_classes, val_data_loader, device):
    # the best checkpoint is saved by Catalyst under <logdir>/checkpoints/best.pth
    ckpt_fp = os.path.join("logs", "checkpoints", "best.pth")
    mod = DenseNet(num_classes=num_classes)
    mod = load_ckpt(ckpt_fp, mod).eval()
    metrics = validate_model(mod, val_data_loader, device)
    return metrics

def main(run_test=True):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    set_global_seed(42)
    prepare_cudnn(True)

    transform_basic, transform_augment = get_transforms()
    train_dataset = datasets.CIFAR10(
        "./cifar10", train=True, download=True, transform=transform_augment
    )
    valid_dataset = datasets.CIFAR10(
        "./cifar10", train=False, download=True, transform=transform_basic
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
    loaders = {
        "train": train_loader,
        "valid": valid_loader,
    }

    runner = dl.SupervisedRunner(
        input_key="img", output_key="logits", target_key="targets", loss_key="loss"
    )

    num_classes = len(train_dataset.classes)
    model = DenseNet(num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        loggers={"console": CustomLogger()},
        num_epochs=6,
        callbacks=[
            dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        ],
        logdir="./logs",
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=True,
        load_best_on_end=True,
    )

    if run_test:
        test(num_classes, loaders["valid"], device)


if __name__ == "__main__":
    main()
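
# Hypothetical invocation (assuming this file is saved as train.py):
#
#   python train.py
#
# Training logs and checkpoints are written under ./logs, and the test phase then
# reloads the best checkpoint and reports accuracy on the validation loader.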