Skip to content

Commit

Permalink
Create test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
QingpingZheng authored Mar 22, 2022
1 parent 9c78982 commit af2d71f
Showing 1 changed file with 182 additions and 0 deletions.
182 changes: 182 additions & 0 deletions parsing/dml_csr/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : qingpingzheng2014@gmail.com
@File : datasets.py
@Time : 10/01/21 00:00 PM
@Desc :
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2015 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

from copy import deepcopy
from inplace_abn import InPlaceABN
from dataset import datasets
from networks import dml_csr
from utils import miou

torch.multiprocessing.set_start_method("spawn", force=True)

DATA_DIRECTORY = './datasets/Helen'
IGNORE_LABEL = 255
NUM_CLASSES = 20
SNAPSHOT_DIR = './snapshots/'
INPUT_SIZE = (473,473)


def get_arguments():
"""Parse all the arguments provided from the CLI.
Returns:
A list of parsed arguments.
"""
parser = argparse.ArgumentParser(description="DML_CSR Network")
parser.add_argument("--batch-size", type=int, default=1,
help="Number of images sent to the network in one step.")
parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
help="Path to the directory containing the PASCAL VOC dataset.")
parser.add_argument("--out-dir", type=str, default=DATA_DIRECTORY,
help="Path to the directory containing the PASCAL VOC dataset.")
parser.add_argument("--dataset", type=str, default='val',
help="Path to the file listing the images in the dataset.")
parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
help="The index of the label to ignore during the training.")
parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
help="Number of classes to predict (including background).")
parser.add_argument("--restore-from", type=str,
help="Where restore model parameters from.")
parser.add_argument("--gpu", type=str, default='7',
help="choose gpu device.")
parser.add_argument("--input-size", type=str, default=INPUT_SIZE,
help="Comma-separated string with height and width of images.")
parser.add_argument("--local_rank", type=int, default=0,
help="choose gpu numbers")
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument("--model_type", type=int, default=0,
help="choose model type")
return parser.parse_args()


def valid(model, valloader, input_size, num_samples, dir=None, dir_edge=None, dir_img=None):

height = input_size[0]
width = input_size[1]
with torch.autograd.profiler.profile(enabled=True, use_cuda=True, \
record_shapes=False, profile_memory=False) as prof:
model.eval()
parsing_preds = np.zeros((num_samples, height, width), dtype=np.uint8)
scales = np.zeros((num_samples, 2), dtype=np.float32)
centers = np.zeros((num_samples, 2), dtype=np.int32)

idx = 0
interp = torch.nn.Upsample(size=(height, width), mode='bilinear', align_corners=True)

with torch.no_grad():
for index, batch in enumerate(valloader):
image, meta = batch
num_images = image.size(0)
if index % 10 == 0:
print('%d processd' % (index * num_images))

c = meta['center'].numpy()
s = meta['scale'].numpy()
scales[idx:idx + num_images, :] = s[:, :]
centers[idx:idx + num_images, :] = c[:, :]

results = model(image.cuda())
outputs = results

if isinstance(results, list):
outputs = results[0]

if isinstance(outputs, list):
for k, output in enumerate(outputs):
parsing = output
nums = len(parsing)
parsing = interp(parsing).data.cpu().numpy()
parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC
parsing_preds[idx:idx + nums, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8)
idx += nums
else:
parsing = outputs
parsing = interp(parsing).data.cpu().numpy()
parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC
parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8)

if dir is not None:
for i in range(len(meta['name'])):
cv2.imwrite(os.path.join(dir, meta['name'][i] + '.png'), np.asarray(np.argmax(parsing, axis=3))[i])
idx += num_images
parsing_preds = parsing_preds[:num_samples, :, :]

return parsing_preds, scales, centers


def main():
"""Create the model and start the evaluation process."""

args = get_arguments()

os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
gpus = [int(i) for i in args.gpu.split(',')]

print(args.gpu)

h, w = map(int, args.input_size.split(','))

input_size = (h, w)

cudnn.benchmark = True
cudnn.enabled = True

model = dml_csr.DML_CSR(args.num_classes, InPlaceABN, False)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

transform = transforms.Compose([
transforms.ToTensor(),
normalize,
])

dataset = datasets.FaceDataSet(args.data_dir, args.dataset, \
crop_size=input_size, transform=transform)
num_samples = len(dataset)

valloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, \
shuffle=False, pin_memory=True)

restore_from = args.restore_from
print(restore_from)
state_dict = torch.load(restore_from,map_location='cuda:0')
model.load_state_dict(state_dict)

model.cuda()
model.eval()

save_path = os.path.join(args.out_dir, args.dataset, 'parsing')
if not os.path.exists(save_path):
os.makedirs(save_path)

parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, save_path)
mIoU, f1 = miou.compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, \
input_size, args.dataset, reverse=True)

print(mIoU)
print(f1)

if __name__ == '__main__':
main()

0 comments on commit af2d71f

Please sign in to comment.