diff --git a/engine.py b/engine.py index 9d1f830..258084e 100644 --- a/engine.py +++ b/engine.py @@ -2,7 +2,7 @@ # All rights reserved. # ------------------------------------------ # Modification: -# Added code for adjusting keep rate -- Youwei Liang +# Added code for adjusting keep rate and visualization -- Youwei Liang """ Train and eval functions used in main.py """ @@ -19,6 +19,8 @@ import utils from helpers import adjust_keep_rate +from visualize_mask import get_real_idx, mask, save_img_batch +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, @@ -150,3 +152,55 @@ def get_acc(data_loader, model, device, keep_rate=None, tokens=None): .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) return metric_logger.acc1.global_avg + + +@torch.no_grad() +def visualize_mask(data_loader, model, device, output_dir, n_visualization, keep_rate=None): + criterion = torch.nn.CrossEntropyLoss() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Visualize:' + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + mean = torch.tensor(IMAGENET_DEFAULT_MEAN, device=device).reshape(3, 1, 1) + std = torch.tensor(IMAGENET_DEFAULT_STD, device=device).reshape(3, 1, 1) + + # switch to evaluation mode + model.eval() + + ii = 0 + for images, target in metric_logger.log_every(data_loader, 10, header): + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + B = images.size(0) + + with torch.cuda.amp.autocast(): + output, idx = model(images, keep_rate, get_idx=True) + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + # denormalize + images = images * std + mean + + idxs = get_real_idx(idx) + for jj, idx in enumerate(idxs): + masked_img = mask(images, patch_size=16, idx=idx) + save_img_batch(masked_img, output_dir, file_name='img_{}' + f'_l{jj}.jpg', start_idx=world_size * B * ii + rank * B) + + save_img_batch(images, output_dir, file_name='img_{}_a.jpg', start_idx=world_size * B * ii + rank * B) + + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + metric_logger.synchronize_between_processes() + ii += 1 + if world_size * B * ii >= n_visualization: + break + + metric_logger.synchronize_between_processes() + print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/evit.py b/evit.py index d07acbb..75dfbfc 100644 --- a/evit.py +++ b/evit.py @@ -242,7 +242,7 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_d self.mlp_hidden_dim = mlp_hidden_dim self.fuse_token = fuse_token - def forward(self, x, keep_rate=None, tokens=None): + def forward(self, x, keep_rate=None, tokens=None, get_idx=False): if keep_rate is None: keep_rate = self.keep_rate # this is for inference, use the default keep rate B, N, C = x.shape @@ -267,7 +267,14 @@ def forward(self, x, keep_rate=None, tokens=None): x = x + self.drop_path(self.mlp(self.norm2(x))) n_tokens = x.shape[1] - 1 - return x, n_tokens + if get_idx and index is not None: + idx = idx[:, :, 0] + # if self.fuse_token: + # # always set the idx of the extra token to 0 + # B = idx.size(0) + # idx = torch.cat([idx, torch.zeros(B, 1, dtype=idx.dtype, device=idx.device)], dim=1) # [B, M] + return x, n_tokens, idx + return x, n_tokens, None class EViT(nn.Module): @@ -391,7 +398,7 @@ def reset_classifier(self, num_classes, global_pool=''): def name(self): return "EViT" - def forward_features(self, x, keep_rate=None, tokens=None): + def forward_features(self, x, keep_rate=None, tokens=None, get_idx=False): _, _, h, w = x.shape if not isinstance(keep_rate, (tuple, list)): keep_rate = (keep_rate, ) * self.depth @@ -422,17 +429,21 @@ def forward_features(self, x, keep_rate=None, tokens=None): x = self.pos_drop(x + pos_embed) left_tokens = [] + if get_idx: + idxs = [] for i, blk in enumerate(self.blocks): - x, left_token = blk(x, keep_rate[i], tokens[i]) + x, left_token, idx = blk(x, keep_rate[i], tokens[i], get_idx) left_tokens.append(left_token) + if idx is not None: + idxs.append(idx) x = self.norm(x) if self.dist_token is None: - return self.pre_logits(x[:, 0]), left_tokens + return self.pre_logits(x[:, 0]), left_tokens, idxs else: - return x[:, 0], x[:, 1] + return x[:, 0], x[:, 1], idxs - def forward(self, x, keep_rate=None, tokens=None, speed_test=False): - x, left_tokens = self.forward_features(x, keep_rate, tokens) + def forward(self, x, keep_rate=None, tokens=None, get_idx=False): + x, _, idxs = self.forward_features(x, keep_rate, tokens, get_idx) if self.head_dist is not None: x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple if self.training and not torch.jit.is_scripting(): @@ -442,8 +453,8 @@ def forward(self, x, keep_rate=None, tokens=None, speed_test=False): return (x + x_dist) / 2 else: x = self.head(x) - if speed_test: - return x, left_tokens + if get_idx: + return x, idxs return x diff --git a/main.py b/main.py index 6065769..b4c6db1 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma from datasets import build_dataset -from engine import train_one_epoch, evaluate +from engine import train_one_epoch, evaluate, visualize_mask from losses import DistillationLoss from samplers import RASampler import models @@ -177,6 +177,8 @@ def get_args_parser(): parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') + parser.add_argument('--visualize_mask', action='store_true', help='Visualize the dropped image patches and then exit') + parser.add_argument('--n_visualization', default=128, type=int) parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--pin-mem', action='store_true', @@ -427,6 +429,10 @@ def log_func1(*arg, **kwargs): print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") return + if args.visualize_mask: + visualize_mask(data_loader_val, model, device, args.output_dir, args.n_visualization) + return + print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 diff --git a/visualize_mask.py b/visualize_mask.py new file mode 100644 index 0000000..35447de --- /dev/null +++ b/visualize_mask.py @@ -0,0 +1,51 @@ +import os +import torch +from torchvision.utils import save_image +from einops import rearrange + + +def mask(x, idx, patch_size): + """ + Args: + x: input image, shape: [B, 3, H, W] + idx: indices of masks, shape: [B, T], value in range [0, h*w) + Return: + out_img: masked image with only patches from idx postions + """ + h = x.size(2) // patch_size + x = rearrange(x, 'b c (h p) (w q) -> b (c p q) (h w)', p=patch_size, q=patch_size) + output = torch.zeros_like(x) + idx1 = idx.unsqueeze(1).expand(-1, x.size(1), -1) + extracted = torch.gather(x, dim=2, index=idx1) # [b, c p q, T] + scattered = torch.scatter(output, dim=2, index=idx1, src=extracted) + out_img = rearrange(scattered, 'b (c p q) (h w) -> b c (h p) (w q)', p=patch_size, q=patch_size, h=h) + return out_img + + +def get_deeper_idx(idx1, idx2): + """ + Args: + idx1: indices, shape: [B, T1] + idx2: indices to gather from idx1, shape: [B, T2], T2 <= T1 + """ + return torch.gather(idx1, dim=1, index=idx2) + + +def get_real_idx(idxs, img_size=224, patch_size=16): + # nh = img_size // patch_size + # npatch = nh ** 2 + for i in range(len(idxs)): + idxs[i] = idxs[i][:, 1:] - 1 # remove cls token idx + + # gather real idx + for i in range(1, len(idxs)): + tmp = idxs[i - 1] + B = tmp.size(0) + tmp = torch.cat([tmp, torch.zeros(B, 1, dtype=tmp.dtype, device=tmp.device)], dim=1) + idxs[i] = torch.gather(tmp, dim=1, index=idxs[i]) + return idxs + + +def save_img_batch(x, path, file_name='img{}', start_idx=0): + for i, img in enumerate(x): + save_image(img, os.path.join(path, file_name.format(start_idx + i))) \ No newline at end of file