Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
models
myenv
27 changes: 25 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,34 @@ We provide a simply script to get the visualization result on the CIHP dataset u
# Example of inference
python exp/inference/inference.py \
--loadmodel /path_to_inference_model \
--img_path ./img/messi.jpg \
--img_path ./img/ronaldo.jpg \
--output_path ./img/ \
--output_name /output_file_name
--output_name ronaldo_output
```

💻 **Note for macOS users (Apple M1/M2/M3):**
```shell
# Example of inference
python exp/inference/inference_mac.py \
--loadmodel /path_to_inference_model \
--img_path ./img/ronaldo.jpg \
--output_path ./img/ \
--output_name ronaldo_output
```
This script uses **PyTorch MPS (Metal Performance Shaders)** for GPU acceleration on macOS.


### Blend (Overlay) Result
After running inference, you can visualize the segmentation result by blending the original image and the mask together.
```shell
# Example of BLEND
python exp/inference/inference_image_blend.py \
--img_path ./img/ronaldo.jpg \
--mask_path ./img/ronaldo_output.png \
--output_path ./img/ \
--output_name ronaldo_blend.png
```

### Training
#### Transfer learning
1. Download the Pascal pretrained model(available soon).
Expand Down
2 changes: 2 additions & 0 deletions dataloaders/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__
/__pycache__
58 changes: 58 additions & 0 deletions exp/inference/inference_image_blend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import cv2
import os
import argparse

def create_overlay(img_path, mask_path, output_path, output_name):
"""
Combines an original image and its parsed mask into a blended overlay.
Args:
img_path (str): Full path to the original image.
mask_path (str): Full path to the parsed mask image.
output_path (str): Directory where the overlay will be saved.
output_name (str): Filename for the final overlay result.
"""

print("Starting overlay process...")
print(f"Original: {img_path}")
print(f"Mask: {mask_path}")

# Load images
original = cv2.imread(img_path)
mask = cv2.imread(mask_path)

# Validate
if original is None or mask is None:
raise FileNotFoundError(f"Could not load image or mask.\noriginal={img_path}\nmask={mask_path}")

# Resize mask to match original
mask = cv2.resize(mask, (original.shape[1], original.shape[0]))

# Blend both (semi-transparent)
overlay = cv2.addWeighted(original, 0.6, mask, 0.4, 0)

# Create output directory if needed
os.makedirs(output_path, exist_ok=True)

# Save overlay
output_file = os.path.join(output_path, output_name)
cv2.imwrite(output_file, overlay)

print(f"Overlay saved to: {output_file}")
return output_file



parser = argparse.ArgumentParser(description="Overlay original image and its parsed mask")
parser.add_argument("--img_path", required=True, help="Path to the original image")
parser.add_argument("--mask_path", required=True, help="Path to the mask image")
parser.add_argument("--output_path", required=True, help="Directory for saving output")
parser.add_argument("--output_name", required=True, help="Filename for the saved overlay")

args = parser.parse_args()

create_overlay(
img_path=args.img_path,
mask_path=args.mask_path,
output_path=args.output_path,
output_name=args.output_name
)
188 changes: 188 additions & 0 deletions exp/inference/inference_mac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import socket
import timeit
import numpy as np
from PIL import Image
from datetime import datetime
import os
import sys
from collections import OrderedDict
sys.path.append('./')

import torch
from torch.autograd import Variable
from torchvision import transforms
import cv2

from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

import argparse
import torch.nn.functional as F

# ---------- Device selection: MPS → CPU ----------
if torch.backends.mps.is_available():
device = torch.device("mps")
print("Using Apple Metal (MPS)")
else:
device = torch.device("cpu")
print("Using CPU")

label_colours = [(0,0,0)
, (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
, (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]

def flip(x, dim):
indices = [slice(None)] * x.dim()
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
dtype=torch.long, device=x.device)
return x[tuple(indices)]

def flip_cihp(tail_list):
'''

:param tail_list: tail_list size is 1 x n_class x h x w
:return:
'''
# tail_list = tail_list[0]
tail_list_rev = [None] * 20
for xx in range(14):
tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
tail_list_rev[14] = tail_list[15].unsqueeze(0)
tail_list_rev[15] = tail_list[14].unsqueeze(0)
tail_list_rev[16] = tail_list[17].unsqueeze(0)
tail_list_rev[17] = tail_list[16].unsqueeze(0)
tail_list_rev[18] = tail_list[19].unsqueeze(0)
tail_list_rev[19] = tail_list[18].unsqueeze(0)
return torch.cat(tail_list_rev,dim=0)

def decode_labels(mask, num_images=1, num_classes=20):
"""Decode batch of segmentation masks.

Args:
mask: result of inference after taking argmax.
num_images: number of images to decode from the batch.
num_classes: number of classes to predict (including background).

Returns:
A batch with num_images RGB images of the same size as the input.
"""
n, h, w = mask.shape
assert (n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (
n, num_images)
outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
for i in range(num_images):
img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
pixels = img.load()
for j_, j in enumerate(mask[i, :, :]):
for k_, k in enumerate(j):
if k < num_classes:
pixels[k_, j_] = label_colours[k]
outputs[i] = np.array(img)
return outputs

def read_img(img_path):
_img = Image.open(img_path).convert('RGB') # return is RGB pic
return _img

def img_transform(img, transform=None):
sample = {'image': img, 'label': 0}

sample = transform(sample)
return sample

@torch.no_grad()
def inference(net, img_path='', output_path='./', output_name='f'):
# ----- build adjacencies on correct device -----
adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float().to(device)
adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1,1,7,20).transpose(2,3)

adj1_ = torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float().to(device)
adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1,1,7,7)

cihp_adj = graph.preprocess_adj(graph.cihp_graph)
adj3_ = torch.from_numpy(cihp_adj).float().to(device)
adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1,1,20,20)

# ----- multi-scale -----
scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
img = read_img(img_path)
testloader_list, testloader_flip_list = [], []

for pv in scale_list:
composed = transforms.Compose([
tr.Scale_only_img(pv),
tr.Normalize_xception_tf_only_img(),
tr.ToTensor_only_img()
])
composed_flip = transforms.Compose([
tr.Scale_only_img(pv),
tr.HorizontalFlip_only_img(),
tr.Normalize_xception_tf_only_img(),
tr.ToTensor_only_img()
])
testloader_list.append(img_transform(img, composed))
testloader_flip_list.append(img_transform(img, composed_flip))

start_time = timeit.default_timer()
net.eval()
outputs_final = None

for iii, (sb, sb_flip) in enumerate(zip(testloader_list, testloader_flip_list)):
inputs = sb['image'].unsqueeze(0).to(device)
inputs_f = sb_flip['image'].unsqueeze(0).to(device)
inputs = torch.cat((inputs, inputs_f), dim=0)

if iii == 0:
_, _, h, w = inputs.size()

# forward
outputs = net.forward(inputs, adj1_test, adj3_test, adj2_test)
# TTA: average original + flipped-back
outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
outputs = outputs.unsqueeze(0)

if iii > 0:
outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
outputs_final = outputs_final + outputs
else:
outputs_final = outputs.clone()

predictions = torch.max(outputs_final, 1)[1]
results = predictions.detach().cpu().numpy()
vis_res = decode_labels(results)

os.makedirs(output_path, exist_ok=True)
Image.fromarray(vis_res[0]).save(os.path.join(output_path, f'{output_name}.png'))
cv2.imwrite(os.path.join(output_path, f'{output_name}_gray.png'), results[0, :, :])

print('time used for the multi-scale image inference:', timeit.default_timer() - start_time)

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--loadmodel', default='', type=str)
parser.add_argument('--img_path', default='', type=str)
parser.add_argument('--output_path', default='', type=str)
parser.add_argument('--output_name', default='', type=str)
args = parser.parse_args()

net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
n_classes=20, hidden_layers=128, source_classes=7
)

if args.loadmodel:
state = torch.load(args.loadmodel, map_location=device)
net.load_source_model(state)
print('Loaded model:', args.loadmodel)
else:
raise RuntimeError('No model supplied via --loadmodel')

net.to(device).eval()

if not args.img_path:
raise RuntimeError('Provide --img_path')
if not args.output_path:
args.output_path = './outputs'
if not args.output_name:
args.output_name = os.path.splitext(os.path.basename(args.img_path))[0]

inference(net=net, img_path=args.img_path, output_path=args.output_path, output_name=args.output_name)
Binary file added img/ronaldo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/ronaldo_blend.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/ronaldo_output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/ronaldo_output_gray.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions networks/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__
/__pycache__
2 changes: 2 additions & 0 deletions sync_batchnorm/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__
/__pycache__