Skip to content

Commit

Permalink
visualize masked images
Browse files Browse the repository at this point in the history
  • Loading branch information
youweiliang committed Apr 29, 2022
1 parent c988a40 commit 180a7a3
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 12 deletions.
56 changes: 55 additions & 1 deletion engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# All rights reserved.
# ------------------------------------------
# Modification:
# Added code for adjusting keep rate -- Youwei Liang
# Added code for adjusting keep rate and visualization -- Youwei Liang
"""
Train and eval functions used in main.py
"""
Expand All @@ -19,6 +19,8 @@
import utils

from helpers import adjust_keep_rate
from visualize_mask import get_real_idx, mask, save_img_batch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
Expand Down Expand Up @@ -150,3 +152,55 @@ def get_acc(data_loader, model, device, keep_rate=None, tokens=None):
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

return metric_logger.acc1.global_avg


@torch.no_grad()
def visualize_mask(data_loader, model, device, output_dir, n_visualization, keep_rate=None):
    """Run the model on ``data_loader`` and save original and masked images.

    For every batch the model is called with ``get_idx=True``; the returned
    indices are converted with ``get_real_idx`` and, for each resulting index
    tensor ``jj``, the denormalized images are masked with ``mask`` (patch
    size 16) and written as ``img_<n>_l<jj>.jpg``. The unmasked denormalized
    images are written as ``img_<n>_a.jpg``. Loss and top-1/top-5 accuracy are
    accumulated in a ``utils.MetricLogger`` along the way.

    Args:
        data_loader: iterable yielding ``(images, target)`` batches.
        model: network called as ``model(images, keep_rate, get_idx=True)``
            and expected to return ``(output, idx)``.
        device: device the batches and normalization constants are moved to.
        output_dir: directory the image files are written into.
        n_visualization: iteration stops once
            ``world_size * batch_size * n_batches >= n_visualization``.
        keep_rate: forwarded unchanged to the model (``None`` by default).

    Returns:
        dict mapping each meter name in the metric logger to its
        ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Visualize:'

    # Querying rank / world size raises when no default process group exists
    # (e.g. a plain single-process run), so fall back to the one-process values.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank, world_size = 0, 1

    # Per-channel statistics shaped (3, 1, 1) so they broadcast over H and W.
    mean = torch.tensor(IMAGENET_DEFAULT_MEAN, device=device).reshape(3, 1, 1)
    std = torch.tensor(IMAGENET_DEFAULT_STD, device=device).reshape(3, 1, 1)

    # switch to evaluation mode
    model.eval()

    ii = 0  # number of batches processed so far by this process
    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        B = images.size(0)

        with torch.cuda.amp.autocast():
            output, idx = model(images, keep_rate, get_idx=True)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # denormalize
        images = images * std + mean

        # File numbering interleaves processes: each iteration reserves
        # world_size * B numbers and this process takes the slice at rank * B.
        # NOTE(review): uses the current batch's B, so a smaller final batch
        # could reuse earlier numbers -- confirm the loader's batch sizes.
        start_idx = world_size * B * ii + rank * B

        idxs = get_real_idx(idx)
        for jj, layer_idx in enumerate(idxs):
            masked_img = mask(images, patch_size=16, idx=layer_idx)
            save_img_batch(masked_img, output_dir, file_name='img_{}' + f'_l{jj}.jpg', start_idx=start_idx)

        save_img_batch(images, output_dir, file_name='img_{}_a.jpg', start_idx=start_idx)

        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=B)
        metric_logger.meters['acc5'].update(acc5.item(), n=B)
        # NOTE(review): synchronizing inside the loop as well as after it may
        # re-reduce already-synchronized totals -- verify against
        # utils.MetricLogger before relying on the reported averages.
        metric_logger.synchronize_between_processes()
        ii += 1
        if world_size * B * ii >= n_visualization:
            break

    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
31 changes: 21 additions & 10 deletions evit.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_d
self.mlp_hidden_dim = mlp_hidden_dim
self.fuse_token = fuse_token

def forward(self, x, keep_rate=None, tokens=None):
def forward(self, x, keep_rate=None, tokens=None, get_idx=False):
if keep_rate is None:
keep_rate = self.keep_rate # this is for inference, use the default keep rate
B, N, C = x.shape
Expand All @@ -267,7 +267,14 @@ def forward(self, x, keep_rate=None, tokens=None):

x = x + self.drop_path(self.mlp(self.norm2(x)))
n_tokens = x.shape[1] - 1
return x, n_tokens
if get_idx and index is not None:
idx = idx[:, :, 0]
# if self.fuse_token:
# # always set the idx of the extra token to 0
# B = idx.size(0)
# idx = torch.cat([idx, torch.zeros(B, 1, dtype=idx.dtype, device=idx.device)], dim=1) # [B, M]
return x, n_tokens, idx
return x, n_tokens, None


class EViT(nn.Module):
Expand Down Expand Up @@ -391,7 +398,7 @@ def reset_classifier(self, num_classes, global_pool=''):
def name(self):
return "EViT"

def forward_features(self, x, keep_rate=None, tokens=None):
def forward_features(self, x, keep_rate=None, tokens=None, get_idx=False):
_, _, h, w = x.shape
if not isinstance(keep_rate, (tuple, list)):
keep_rate = (keep_rate, ) * self.depth
Expand Down Expand Up @@ -422,17 +429,21 @@ def forward_features(self, x, keep_rate=None, tokens=None):
x = self.pos_drop(x + pos_embed)

left_tokens = []
if get_idx:
idxs = []
for i, blk in enumerate(self.blocks):
x, left_token = blk(x, keep_rate[i], tokens[i])
x, left_token, idx = blk(x, keep_rate[i], tokens[i], get_idx)
left_tokens.append(left_token)
if idx is not None:
idxs.append(idx)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0]), left_tokens
return self.pre_logits(x[:, 0]), left_tokens, idxs
else:
return x[:, 0], x[:, 1]
return x[:, 0], x[:, 1], idxs

def forward(self, x, keep_rate=None, tokens=None, speed_test=False):
x, left_tokens = self.forward_features(x, keep_rate, tokens)
def forward(self, x, keep_rate=None, tokens=None, get_idx=False):
x, _, idxs = self.forward_features(x, keep_rate, tokens, get_idx)
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
Expand All @@ -442,8 +453,8 @@ def forward(self, x, keep_rate=None, tokens=None, speed_test=False):
return (x + x_dist) / 2
else:
x = self.head(x)
if speed_test:
return x, left_tokens
if get_idx:
return x, idxs
return x


Expand Down
8 changes: 7 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from timm.utils import NativeScaler, get_state_dict, ModelEma

from datasets import build_dataset
from engine import train_one_epoch, evaluate
from engine import train_one_epoch, evaluate, visualize_mask
from losses import DistillationLoss
from samplers import RASampler
import models
Expand Down Expand Up @@ -177,6 +177,8 @@ def get_args_parser():
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--visualize_mask', action='store_true', help='Visualize the dropped image patches and then exit')
parser.add_argument('--n_visualization', default=128, type=int)
parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin-mem', action='store_true',
Expand Down Expand Up @@ -427,6 +429,10 @@ def log_func1(*arg, **kwargs):
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return

if args.visualize_mask:
visualize_mask(data_loader_val, model, device, args.output_dir, args.n_visualization)
return

print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
Expand Down
51 changes: 51 additions & 0 deletions visualize_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import torch
from torchvision.utils import save_image
from einops import rearrange


def mask(x, idx, patch_size):
    """Keep only the selected patches of each image and zero out the rest.

    Args:
        x: input images, shape [B, 3, H, W]
        idx: indices of the patches to keep, shape [B, T], values in [0, h*w)
        patch_size: side length of a square patch in pixels
    Return:
        images of the same shape as ``x`` in which only the patches at the
        ``idx`` positions are copied from ``x``; all other pixels are zero
    """
    rows = x.size(2) // patch_size  # number of patch rows, needed to unflatten

    # Flatten every patch into one column: [b, c*p*q, h*w].
    patches = rearrange(x, 'b c (h p) (w q) -> b (c p q) (h w)', p=patch_size, q=patch_size)

    # Repeat the patch indices along the flattened-pixel axis so they can
    # address whole columns of ``patches``.
    columns = idx.unsqueeze(1).expand(-1, patches.size(1), -1)

    kept = patches.gather(2, columns)  # [b, c p q, T]
    canvas = torch.zeros_like(patches).scatter(2, columns, kept)

    return rearrange(canvas, 'b (c p q) (h w) -> b c (h p) (w q)', p=patch_size, q=patch_size, h=rows)


def get_deeper_idx(idx1, idx2):
    """Look up, row by row, the entries of ``idx1`` at the positions in ``idx2``.

    Args:
        idx1: indices, shape: [B, T1]
        idx2: indices to gather from idx1, shape: [B, T2], T2 <= T1
    Return:
        tensor of shape [B, T2] holding ``idx1[b, idx2[b, t]]``
    """
    selected = idx1.gather(1, idx2)
    return selected


def get_real_idx(idxs, img_size=224, patch_size=16):
    """Rewrite per-layer indices so each refers to the first layer's positions.

    ``idxs`` is modified in place and also returned.

    Args:
        idxs: list of index tensors, one per layer, each of shape [B, T_i]
        img_size: unused
        patch_size: unused
    Return:
        the same list object, with the first column of every tensor removed,
        the remaining values shifted down by one, and every tensor after the
        first chained through its predecessor
    """
    # Drop the leading (cls token) column and shift the rest down by one.
    for layer, raw in enumerate(idxs):
        idxs[layer] = raw[:, 1:] - 1

    # Chain each layer through the previous one. A zero column is appended to
    # the previous layer's indices, so an index one past its last column
    # resolves to 0.
    for layer in range(1, len(idxs)):
        parent = idxs[layer - 1]
        padding = parent.new_zeros(parent.size(0), 1)
        lookup = torch.cat((parent, padding), dim=1)
        idxs[layer] = lookup.gather(1, idxs[layer])

    return idxs


def save_img_batch(x, path, file_name='img{}', start_idx=0):
    """Write every image in a batch to its own file under ``path``.

    Args:
        x: iterable of images; each item is passed to ``save_image``
        path: directory the files are written into
        file_name: format string with one ``{}`` placeholder for the number
        start_idx: number given to the first image; later images count up
    """
    for offset, img in enumerate(x):
        number = start_idx + offset
        destination = os.path.join(path, file_name.format(number))
        save_image(img, destination)

0 comments on commit 180a7a3

Please sign in to comment.