alpha version of complete model
conradry committed Feb 8, 2021
1 parent 0ef0d8a commit 25fda44
Showing 7 changed files with 625 additions and 20 deletions.
49 changes: 39 additions & 10 deletions README.md
@@ -15,7 +15,7 @@ This repository is under active development. Currently, only the MaX-DeepLab-S a
- [x] MaX-DeepLab-S architecture
- [x] Hungarian Matcher
- [x] PQ-style loss
- [ ] Auxiliary losses (Instance discrimination, Mask-ID cross-entropy, Semantic Segmentation)
- [x] Auxiliary losses (Instance discrimination, Mask-ID cross-entropy, Semantic Segmentation)
- [ ] Optimize model runtime
- [ ] Encoder pre-training on ImageNet
- [ ] MaX-DeepLab-S training on COCO Panoptic
@@ -27,23 +27,52 @@ MaX-DeepLab has a complex architecture, training procedure and loss function. An

```python
from max_deeplab.model import MaXDeepLabS
from max_deeplab.losses import MaXDeepLabLoss
from datasets.coco_panoptic import build

import numpy as np
import torch
from torch.utils.data import DataLoader
from utils.misc import collate_fn #padding collate helper used below

model = MaXDeepLabS(im_size=640, n_classes=80)
config = {}
config['image_size'] = (640, 640)
config['coco_path'] = '../../datasets/coco_panoptic/' #directory with image and annotation data
data = build('train', config)

P = torch.randn((4, 3, 640, 640))
M = torch.randn((50, 4, 256))
#create a dataloader that collates batch with padding
#see utils.misc.collate_fn
padding_dict = {'image': 0, 'masks': 0, 'semantic_mask': 0, 'labels': 201, 'image_id': 0} #201 is 'no_class' class
sizes_dict = {'image': None, 'masks': 128, 'semantic_mask': None, 'labels': 128, 'image_id': None}
collate_lambda = lambda b: collate_fn(b, padding_dict, sizes_dict)
loader = DataLoader(data, batch_size=8, shuffle=True, collate_fn=collate_lambda)

mask_out, classes = model(P, M)
print(mask_out.shape, classes.shape)
>>> (torch.Size([4, 50, 640, 640]), torch.Size([4, 50, 80]))
batch = next(iter(loader))

#returns a dictionary of NestedTensors (each has a 'tensors' and 'sizes' attribute)
#'sizes' is the number of ground truth masks for an image that are not from padding
print(batch['image'].tensors.size(), batch['masks'].tensors.size(),
      batch['labels'].tensors.size(), batch['semantic_mask'].tensors.size())
>>> (torch.Size([8, 3, 640, 640]), torch.Size([8, 128, 640, 640]), torch.Size([8, 128]), torch.Size([8, 640, 640]))

model = MaXDeepLabS(im_size=640, n_classes=202, n_masks=128)
criterion = MaXDeepLabLoss()

num_params = []
for pn, p in model.named_parameters():
    num_params.append(np.prod(p.size()))

print(f'{sum(num_params):,} total parameters.')
>>> 61,849,316 total parameters.
>>> 61,873,172 total parameters.

P = batch['image']
M = torch.randn((128, 8, 256))

mask_out, classes, semantic = model(P, M)
print(mask_out.shape, classes.shape, semantic.shape)
>>> (torch.Size([8, 128, 640, 640]), torch.Size([8, 128, 202]), torch.Size([8, 202, 640, 640]))

loss = criterion((mask_out, classes, semantic), (batch['masks'], batch['labels'], batch['semantic_mask']))
print(loss) #returns loss value and a dict of loss items for each loss
>>> (tensor(-4.5043),
     {'pq': 0.0689038634300232,
      'instdisc': -11.75303840637207,
      'maskid': 5.25504732131958,
      'semantic': 5.465500354766846})
```

(Reported number of parameters in the paper is 61.9M)
(Note: the number of parameters reported in the paper is 61.9M.)
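
The snippet above leans on a padding collate and a NestedTensor container (referenced as utils.misc) that are not part of this diff. As a rough illustration only, a minimal version might look like the sketch below; the names and exact behavior are assumptions, not the repository's implementation.

```python
import torch

class NestedTensor:
    """Bundles a padded, batched tensor with the unpadded size of each sample."""
    def __init__(self, tensors, sizes):
        self.tensors = tensors  #e.g. (B, K_max, H, W) for 'masks'
        self.sizes = sizes      #number of real (non-padded) entries per sample

def collate_fn(batch, padding_dict, sizes_dict):
    """Pad each key along dim 0 to a fixed size, then stack into NestedTensors."""
    out = {}
    for key in batch[0].keys():
        pad_value, target = padding_dict[key], sizes_dict[key]
        tensors, sizes = [], []
        for sample in batch:
            item = sample[key]
            sizes.append(item.shape[0])
            if target is not None and item.shape[0] < target:
                pad_shape = (target - item.shape[0],) + tuple(item.shape[1:])
                pad = torch.full(pad_shape, pad_value, dtype=item.dtype)
                item = torch.cat([item, pad], dim=0)
            tensors.append(item)
        out[key] = NestedTensor(torch.stack(tensors), torch.tensor(sizes))
    return out
```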
57 changes: 54 additions & 3 deletions datasets/coco_panoptic.py
@@ -8,10 +8,11 @@
import numpy as np
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from panopticapi.utils import rgb2id
from PIL import Image


class CocoPanoptic:
    def __init__(self, img_folder, ann_folder, ann_file, transforms=None):
        with open(ann_file, 'r') as f:
@@ -59,7 +60,57 @@ def __getitem__(self, idx):
        if self.transforms is not None:
            output = self.transforms(**output)

        output['image_id'] = np.array([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])
        output['labels'] = labels
        output['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])
        output['labels'] = torch.tensor(labels, dtype=torch.long)
        output['masks'] = torch.from_numpy(np.stack(output['masks'], axis=0))

        #create semantic segmentation mask
        output['semantic_mask'] = torch.einsum('nhw,n->hw', output['masks'].long(), output['labels'])

        return output

def make_coco_transforms(image_set='train', image_size=(640, 640)):
    normalize = A.Sequential([
        A.Normalize(), #imagenet norms default
        ToTensorV2()
    ])

    if image_set == 'train':
        transforms = A.Compose([
            A.PadIfNeeded(*image_size, border_mode=0), #pad with zeros
            A.RandomResizedCrop(*image_size),
            A.HorizontalFlip(),
            normalize
        ])

    elif image_set == 'val':
        transforms = A.Compose([
            A.Resize(*image_size),
            normalize
        ])

    else:
        raise ValueError(f'{image_set} not recognized!')

    return transforms

def build(image_set, config):
    img_folder_root = config['coco_path']
    ann_folder_root = os.path.join(img_folder_root, 'annotations')
    assert os.path.exists(img_folder_root), f'provided COCO path {img_folder_root} does not exist'
    assert os.path.exists(ann_folder_root), f'provided annotations path {ann_folder_root} does not exist'

    PATHS = {
        "train": ("train2017", 'panoptic_train2017.json'),
        "val": ("val2017", 'panoptic_val2017.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    img_folder_path = os.path.join(img_folder_root, img_folder)
    ann_folder = os.path.join(ann_folder_root, f'panoptic_{img_folder}')
    ann_file = os.path.join(ann_folder_root, ann_file)

    dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file,
                           transforms=make_coco_transforms(image_set, config['image_size']))

    return dataset
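
The semantic_mask line in __getitem__ builds a per-pixel class map by weighting each binary instance mask by its label and summing over instances. A toy worked example (hypothetical tensors, assuming the masks do not overlap):

```python
import torch

#two non-overlapping 2x2 instance masks with class labels 3 and 7
masks = torch.tensor([[[1, 1],
                       [0, 0]],
                      [[0, 0],
                       [1, 1]]])       #(N, H, W) = (2, 2, 2)
labels = torch.tensor([3, 7])          #(N,)

semantic = torch.einsum('nhw,n->hw', masks.long(), labels)
print(semantic)
#tensor([[3, 3],
#        [7, 7]])
#uncovered pixels stay 0; overlapping masks would sum their labels
```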
3 changes: 2 additions & 1 deletion max_deeplab/blocks.py
@@ -425,6 +425,7 @@ class MaskHead(nn.Module):
    def __init__(
        self,
        nplanes,
        nout,
        kernel_size=5,
        padding=2,
        separable=True
@@ -434,7 +435,7 @@ def __init__(
        self.conv5x5 = conv_bn_relu(
            nplanes, nplanes, kernel_size, padding=padding, groups=groups
        )
        self.conv1x1 = conv_bn_relu(nplanes, nplanes, 1, with_relu=False)
        self.conv1x1 = conv_bn_relu(nplanes, nout, 1, with_relu=False)

    def forward(self, x):
        return self.conv1x1(self.conv5x5(x))
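
With the extra nout argument, MaskHead now maps its nplanes input channels to nout output channels (the decoder passes n_masks, per the model.py change below). A quick shape check; the channel counts and spatial size here are assumptions chosen for illustration only:

```python
import torch
from max_deeplab.blocks import MaskHead

head = MaskHead(nplanes=128, nout=50)     #arbitrary illustrative channel counts
features = torch.randn(2, 128, 160, 160)  #(B, nplanes, H', W')
out = head(features)
print(out.shape)                          #expected: torch.Size([2, 50, 160, 160])
```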
64 changes: 63 additions & 1 deletion max_deeplab/losses.py
@@ -112,7 +112,7 @@ def __init__(self, alpha=0.75, eps=1e-5, no_class_index=-1):
        self.no_class_index = no_class_index
        self.matcher = HungarianMatcher()

    def forward(self, input_class, input_mask, target_class, target_mask, target_sizes):
    def forward(self, input_mask, input_class, target_mask, target_class, target_sizes):
        """
        input_class: (B, N, N_CLASSES) #logits
        input_mask: (B, N, H, W) #probabilities [0, 1]
@@ -121,6 +121,7 @@ def forward(self, input_class, input_mask, target_class, target_mask, target_siz
        """
        #apply softmax to get probabilities from logits
        B, N, num_classes = input_class.size()
        input_mask = F.softmax(input_mask, dim=1)
        input_class_prob = F.softmax(input_class, dim=-1)
        input_mask = rearrange(input_mask, 'b n h w -> b n (h w)')
        target_mask = rearrange(target_mask, 'b k h w -> b k (h w)')
@@ -186,6 +187,10 @@ def forward(self, mask_features, target_mask, target_sizes):
        target_mask: (B, K, H, W) #m
        """

        #downsample input and target by 4 to get (B, H/4, W/4)
        mask_features = mask_features[..., ::4, ::4]
        target_mask = target_mask[..., ::4, ::4]

        device = mask_features.device

        #eqn 16
@@ -221,6 +226,18 @@ def forward(self, mask_features, target_mask, target_sizes):
        denominator = logits.sum(-1)
        return -torch.log((numerator + self.eps) / (denominator + self.eps)).mean()

class MaskIDLoss(nn.Module):
    def __init__(self):
        super(MaskIDLoss, self).__init__()
        self.xentropy = nn.CrossEntropyLoss()

    def forward(self, input_mask, target_mask):
        """
        input_mask: (B, N, H, W) #logits
        target_mask: (B, H, W) #long indices of maskID in N
        """
        return self.xentropy(input_mask, target_mask)

class SemanticSegmentationLoss(nn.Module):
    def __init__(self, method='cross_entropy'):
        super(SemanticSegmentationLoss, self).__init__()
@@ -239,3 +256,48 @@ def forward(self, input_mask, target_mask):
        target_mask: (B, H, W) #long indices
        """
        return self.xentropy(input_mask, target_mask)

class MaXDeepLabLoss(nn.Module):
    def __init__(
        self,
        pq_loss_weight=3,
        instance_loss_weight=1,
        maskid_loss_weight=0.3,
        semantic_loss_weight=1,
        alpha=0.75,
        temp=0.3,
        no_class_index=-1,
        eps=1e-5,
    ):
        super(MaXDeepLabLoss, self).__init__()
        self.pqw = pq_loss_weight
        self.idw = instance_loss_weight
        self.miw = maskid_loss_weight
        self.ssw = semantic_loss_weight
        self.pq_loss = PQLoss(alpha, eps, no_class_index)
        self.instance_loss = InstanceDiscLoss(temp, eps)
        self.maskid_loss = MaskIDLoss()
        self.semantic_loss = SemanticSegmentationLoss()

    def forward(self, input_tuple, target_tuple):
        """
        input_tuple: (input_masks, input_classes, input_semantic_segmentation) Tensors
        target_tuple: (gt_masks, gt_classes, gt_semantic_segmentation) NestedTensors
        """
        input_masks, input_classes, input_ss = input_tuple
        gt_masks, gt_classes, gt_ss = [t.tensors for t in target_tuple]
        target_sizes = target_tuple[0].sizes

        pq = self.pq_loss(input_masks, input_classes, gt_masks, gt_classes, target_sizes)
        instdisc = self.instance_loss(input_masks, gt_masks.float(), target_sizes)

        #create the target for the maskid loss using argmax on the ground truth masks
        maskid = self.maskid_loss(input_masks, gt_masks.argmax(1))
        semantic = self.semantic_loss(input_ss, gt_ss)

        loss_items = {'pq': pq.item(), 'instdisc': instdisc.item(),
                      'maskid': maskid.item(), 'semantic': semantic.item()}

        total_loss = self.pqw * pq + self.idw * instdisc + self.miw * maskid + self.ssw * semantic

        return total_loss, loss_items
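
MaskIDLoss treats the N mask slots as classes: the per-pixel target is the index of the ground-truth mask that covers that pixel, which MaXDeepLabLoss recovers with gt_masks.argmax(1). A toy illustration under the assumption that the ground-truth masks are binary and non-overlapping:

```python
import torch
import torch.nn as nn

B, N, H, W = 2, 4, 8, 8
input_masks = torch.randn(B, N, H, W)   #predicted mask logits
gt_masks = torch.zeros(B, N, H, W)
gt_masks[:, 0, :4] = 1                  #mask 0 covers the top half of the image
gt_masks[:, 1, 4:] = 1                  #mask 1 covers the bottom half

maskid_target = gt_masks.argmax(1)      #(B, H, W) long indices in [0, N)
loss = nn.CrossEntropyLoss()(input_masks, maskid_target)
print(maskid_target.unique().tolist())  #[0, 1]
print(loss.item())                      #cross entropy of random logits vs. the target map
```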
24 changes: 19 additions & 5 deletions max_deeplab/model.py
@@ -162,15 +162,16 @@ def __init__(
        )

        nin_pixel = nin_pixel // 2
        self.mask_head = MaskHead(nin_pixel)
        self.mask_head = MaskHead(nin_pixel, n_masks)

        self.mem_mask = nn.Sequential(
            linear_bn_relu(nin_memory, nin_memory),
            linear_bn_relu(nin_memory, nin_pixel, with_relu=False)
            linear_bn_relu(nin_memory, n_masks, with_relu=False)
        )

        self.mem_class = nn.Sequential(
            linear_bn_relu(nin_memory, nin_memory),
            nn.Linear(nin_memory, nin_pixel)
            nn.Linear(nin_memory, n_classes)
        )

        self.fg_bn = nn.BatchNorm2d(n_masks)
@@ -214,12 +215,25 @@ def __init__(
            n_classes=n_classes, n_masks=n_masks
        )

        self.semantic_head = nn.Sequential(
            conv_bn_relu(2048, 256, 5, padding=2, groups=256),
            conv_bn_relu(256, n_classes, 1, with_bn=False, with_relu=False),
            nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
        )

    def forward(self, x):
        return self.conv1x1(self.conv5x5(x))

    def forward(self, P, M):
        """
        P: pixel tensor (B, 3, H, W)
        P: pixel NestedTensor (B, 3, H, W)
        M: memory tensor (N, B, 256)
        """

        #P is a nested tensor, extract the image data
        #see utils.misc.NestedTensor
        P, mask = P.decompose()
        fmaps, mem = self.encoder(P, M)
        semantic_mask = self.semantic_head(fmaps[-1])
        mask_out, classes = self.decoder(fmaps, mem)
        return mask_out, classes
        return mask_out, classes, semantic_mask
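
The new semantic head runs on the deepest encoder feature map and upsamples 16x, which matches a stride-16 backbone output. Assuming fmaps[-1] has 2048 channels at 1/16 resolution (as the 2048 input channels above suggest), the shapes work out as in this rough stand-in built from plain torch modules (conv_bn_relu replaced by bare convs for illustration):

```python
import torch
import torch.nn as nn

n_classes = 202
semantic_head = nn.Sequential(
    nn.Conv2d(2048, 256, 5, padding=2, groups=256),  #grouped 5x5, as in the diff
    nn.Conv2d(256, n_classes, 1),
    nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
)

fmap = torch.randn(1, 2048, 40, 40)   #640 / 16 = 40
print(semantic_head(fmap).shape)      #torch.Size([1, 202, 640, 640])
```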
Empty file added util/__init__.py
