From 6a884011fd7c48ff7e479bb9f0a6129888cd5200 Mon Sep 17 00:00:00 2001
From: tzzcl
Date: Mon, 2 Mar 2020 02:04:33 -0600
Subject: [PATCH] First commit; add validation code

Signed-off-by: tzzcl
---
 PSOL_inference.py                     | 426 ++++++++++++++++++++++++++
 utils/IoU.py                          | 106 +++++++
 utils/__pycache__/IoU.cpython-36.pyc  | Bin 0 -> 2676 bytes
 utils/__pycache__/func.cpython-36.pyc | Bin 0 -> 1744 bytes
 utils/__pycache__/nms.cpython-36.pyc  | Bin 0 -> 938 bytes
 utils/__pycache__/vis.cpython-36.pyc  | Bin 0 -> 794 bytes
 utils/func.py                         |  78 +++++
 utils/nms.py                          |  36 +++
 utils/vis.py                          |  22 ++
 9 files changed, 668 insertions(+)
 create mode 100644 PSOL_inference.py
 create mode 100644 utils/IoU.py
 create mode 100644 utils/__pycache__/IoU.cpython-36.pyc
 create mode 100644 utils/__pycache__/func.cpython-36.pyc
 create mode 100644 utils/__pycache__/nms.cpython-36.pyc
 create mode 100644 utils/__pycache__/vis.cpython-36.pyc
 create mode 100644 utils/func.py
 create mode 100644 utils/nms.py
 create mode 100644 utils/vis.py

diff --git a/PSOL_inference.py b/PSOL_inference.py
new file mode 100644
index 0000000..7677847
--- /dev/null
+++ b/PSOL_inference.py
@@ -0,0 +1,426 @@
+import os
+import sys
+import json
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from torch.backends import cudnn
+from torch.autograd import Variable
+import torch.nn as nn
+import torchvision
+import torchvision.models as models
+from PIL import Image
+from utils.func import *
+from copy import deepcopy
+from utils.vis import *
+from utils.IoU import *
+import copy
+from torchvision.transforms import functional as F
+import numbers
+import argparse
+def compute_intersec(i, j, h, w, bbox):
+    '''
+    intersection box between the cropped box and the GT bbox
+    '''
+    intersec = copy.deepcopy(bbox)
+
+    intersec[0] = max(j, bbox[0])
+    intersec[1] = max(i, bbox[1])
+    intersec[2] = min(j + w, bbox[2])
+    intersec[3] = min(i + h, bbox[3])
+    return intersec
+
+
+def normalize_intersec(i, j, h, w, intersec):
+    '''
+    normalize the intersection box into [0, 1] relative to the crop window
+    '''
+
+    intersec[0] = (intersec[0] - j) / w
+    intersec[2] = (intersec[2] - j) / w
+    intersec[1] = (intersec[1] - i) / h
+    intersec[3] = (intersec[3] - i) / h
+    return intersec
+class ResizedBBoxCrop(object):
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = size
+
+        self.interpolation = interpolation
+
+    @staticmethod
+    def get_params(img, bbox, size):
+        #resize to 256
+        if isinstance(size, int):
+            w, h = img.size
+            if (w <= h and w == size) or (h <= w and h == size):
+                img = copy.deepcopy(img)
+                ow, oh = w, h
+            if w < h:
+                ow = size
+                oh = int(size*h/w)
+            else:
+                oh = size
+                ow = int(size*w/h)
+        else:
+            ow, oh = size[::-1]
+            w, h = img.size
+
+
+        intersec = copy.deepcopy(bbox)
+        ratew = ow / w
+        rateh = oh / h
+        intersec[0] = bbox[0]*ratew
+        intersec[2] = bbox[2]*ratew
+        intersec[1] = bbox[1]*rateh
+        intersec[3] = bbox[3]*rateh
+
+        #intersec = normalize_intersec(i, j, h, w, intersec)
+        return (oh, ow), intersec
+
+    def __call__(self, img, bbox):
+        """
+        Args:
+            img (PIL Image): Image to be resized.
+
+        Returns:
+            PIL Image: Resized image, with the bbox mapped into the resized coordinates.
+        """
+        size, crop_bbox = self.get_params(img, bbox, self.size)
+        return F.resize(img, self.size, self.interpolation), crop_bbox
+
+
+class CenterBBoxCrop(object):
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = size
+
+        self.interpolation = interpolation
+
+    @staticmethod
+    def get_params(img, bbox, size):
+        #center crop
+        if isinstance(size, numbers.Number):
+            output_size = (int(size), int(size))
+
+        w, h = img.size
+        th, tw = output_size
+
+        i = int(round((h - th) / 2.))
+        j = int(round((w - tw) / 2.))
+
+        intersec = compute_intersec(i, j, th, tw, bbox)
+        intersec = normalize_intersec(i, j, th, tw, intersec)
+
+        #intersec = normalize_intersec(i, j, h, w, intersec)
+        return i, j, th, tw, intersec
+
+    def __call__(self, img, bbox):
+        """
+        Args:
+            img (PIL Image): Image to be center-cropped.
+
+        Returns:
+            PIL Image: Center-cropped image, with the bbox normalized to [0, 1] inside the crop.
+        """
+        i, j, th, tw, crop_bbox = self.get_params(img, bbox, self.size)
+        return F.center_crop(img, self.size), crop_bbox
+
+class VGGGAP(nn.Module):
+    def __init__(self, pretrained=True, num_classes=200):
+        super(VGGGAP,self).__init__()
+        self.features = torchvision.models.vgg16(pretrained=pretrained).features
+        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
+        self.classifier = nn.Sequential(nn.Linear(512,512),nn.ReLU(),nn.Linear(512,4),nn.Sigmoid())
+    def forward(self, x):
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0),-1)
+        x = self.classifier(x)
+        return x
+class VGG16(nn.Module):
+    def __init__(self, pretrained=True, num_classes=200):
+        super(VGG16,self).__init__()
+        self.features = torchvision.models.vgg16(pretrained=pretrained).features
+        temp_classifier = torchvision.models.vgg16(pretrained=pretrained).classifier
+        removed = list(temp_classifier.children())
+        removed = removed[:-1]
+        temp_layer = nn.Sequential(nn.Linear(4096,512),nn.ReLU(),nn.Linear(512,4),nn.Sigmoid())
+        removed.append(temp_layer)
+        self.classifier = nn.Sequential(*removed)
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0),-1)
+        x = self.classifier(x)
+        return x
+
+
+def to_variable(x):
+    if torch.cuda.is_available():
+        x = x.cuda()
+    return Variable(x)
+
+def to_data(x):
+    if torch.cuda.is_available():
+        x = x.cpu()
+    return x.data
+
+def copy_parameters(model, pretrained_dict):
+    model_dict = model.state_dict()
+
+    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict and v.size()==model_dict[k[7:]].size()}
+    #for k, v in pretrained_dict.items():
+    #    print(k)
+    model_dict.update(pretrained_dict)
+    model.load_state_dict(model_dict)
+    return model
+def choose_locmodel(model_name):
+    if model_name == 'densenet161':
+        model = torchvision.models.densenet161(pretrained=True)
+
+        model.classifier = nn.Sequential(
+            nn.Linear(2208, 512),
+            nn.ReLU(),
+            nn.Linear(512, 4),
+            nn.Sigmoid()
+        )
+        model = copy_parameters(model, torch.load('densenet161loc.pth.tar'))
+    elif model_name == 'resnet50':
+        model = torchvision.models.resnet50(pretrained=True, num_classes=1000)
+        model.fc = nn.Sequential(
+            nn.Linear(2048, 512),
+            nn.ReLU(),
+            nn.Linear(512, 4),
+            nn.Sigmoid()
+        )
+        model = copy_parameters(model, torch.load('resnet50loc.pth.tar'))
+    elif model_name == 'vgggap':
+        model = VGGGAP(pretrained=True,num_classes=1000)
+        model = copy_parameters(model, torch.load('vgggaploc.pth.tar'))
+    elif model_name == 'vgg16':
+        model = VGG16(pretrained=True,num_classes=1000)
+        model = copy_parameters(model, torch.load('vgg16loc.pth.tar'))
+    elif model_name == 'inceptionv3':
+        # TODO: requires rolling back to the official InceptionV3 implementation
+        raise NotImplementedError('InceptionV3 localization model is not supported yet!')
+    else:
+        raise ValueError('Do not have this model currently!')
+    return model
+def choose_clsmodel(model_name):
+    if model_name == 'vgg16':
+        cls_model = torchvision.models.vgg16(pretrained=True)
+    elif model_name == 'inceptionv3':
+        cls_model = torchvision.models.inception_v3(pretrained=True, aux_logits=True, transform_input=True)
+    elif model_name == 'resnet50':
+        cls_model = torchvision.models.resnet50(pretrained=True)
+    elif model_name == 'densenet161':
+        cls_model = torchvision.models.densenet161(pretrained=True)
+    elif model_name == 'dpn131':
+        cls_model = torch.hub.load('rwightman/pytorch-dpn-pretrained', 'dpn131', pretrained=True,test_time_pool=True)
+    elif model_name == 'efficientnetb7':
+        from efficientnet_pytorch import EfficientNet
+        cls_model = EfficientNet.from_pretrained('efficientnet-b7')
+    else:
+        raise ValueError('Do not have this model currently!')
+    return cls_model
+parser = argparse.ArgumentParser(description='Parameters for PSOL')
+parser.add_argument('--loc-model', metavar='locarg', type=str, default='vgg16',dest='locmodel')
+parser.add_argument('--cls-model', metavar='clsarg', type=str, default='vgg16',dest='clsmodel')
+parser.add_argument('--ten-crop', help='tencrop', action='store_true',dest='tencrop')
+args = parser.parse_args()
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+os.environ['OMP_NUM_THREADS'] = "4"
+os.environ['MKL_NUM_THREADS'] = "4"
+cudnn.benchmark = True
+TEN_CROP = args.tencrop
+normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+transform = transforms.Compose([
+    transforms.Resize((256,256)),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    normalize
+])
+cls_transform = transforms.Compose([
+    transforms.Resize((256,256)),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    normalize
+])
+ten_crop_aug = transforms.Compose([
+    transforms.Resize(256),
+    transforms.TenCrop(256),
+    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
+    transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
+])
+locname = args.locmodel
+model = choose_locmodel(locname)
+
+print(model)
+model = model.to(0)
+model.eval()
+clsname = args.clsmodel
+cls_model = choose_clsmodel(clsname)
+cls_model = cls_model.to(0)
+cls_model.eval()
+
+root = './'
+val_imagedir = os.path.join(root, 'val')
+
+anno_root = './anno/'
+val_annodir = os.path.join(anno_root, 'val')
+
+
+classes = os.listdir(val_imagedir)
+classes.sort()
+temp_softmax = nn.Softmax(dim=1)
+'''
+savepath = 'ImageNet/Visualization/test_PSOL'
+if not os.path.exists(savepath):
+    os.makedirs(savepath)
+'''
+#print(classes[0])
+
+
+class_to_idx = {classes[i]:i for i in range(len(classes))}
+
+result = {}
+
+accs = []
+accs_top5 = []
+loc_accs = []
+cls_accs = []
+final_cls = []
+final_loc = []
+final_clsloc = []
+final_clsloctop5 = []
+final_ind = []
+for k in range(1000):
+    cls = classes[k]
+
+    total = 0
+    IoUSet = []
+    IoUSetTop5 = []
+    LocSet = []
+    ClsSet = []
+
+    files = os.listdir(os.path.join(val_imagedir, cls))
+    files.sort()
+
+    for (i, name) in enumerate(files):
+        # raw_img = cv2.imread(os.path.join(imagedir, cls, name))
+        now_index = int(name.split('_')[-1].split('.')[0])
+        final_ind.append(now_index-1)
+        xmlfile = os.path.join(val_annodir, cls, name.split('.')[0] + '.xml')
+        gt_boxes = get_cls_gt_boxes(xmlfile, cls)
+        if len(gt_boxes)==0:
+            continue
+
+        raw_img = Image.open(os.path.join(val_imagedir, cls, name)).convert('RGB')
+        w, h = raw_img.size
+
+        with torch.no_grad():
+            img = transform(raw_img)
+            img = torch.unsqueeze(img, 0)
+            img = img.to(0)
+            reg_outputs = model(img)
+
+            bbox = to_data(reg_outputs)
+            bbox = torch.squeeze(bbox)
+            bbox = bbox.numpy()
+            if TEN_CROP:
+                img = ten_crop_aug(raw_img)
+                img = img.to(0)
+                vgg16_out = cls_model(img)
+                vgg16_out = temp_softmax(vgg16_out)
+                vgg16_out = torch.mean(vgg16_out,dim=0,keepdim=True)
+                vgg16_out = torch.topk(vgg16_out, 5, 1)[1]
+            else:
+                img = cls_transform(raw_img)
+                img = torch.unsqueeze(img, 0)
+                img = img.to(0)
+                vgg16_out = cls_model(img)
+                vgg16_out = torch.topk(vgg16_out, 5, 1)[1]
+            vgg16_out = to_data(vgg16_out)
+            vgg16_out = torch.squeeze(vgg16_out)
+            vgg16_out = vgg16_out.numpy()
+            out = vgg16_out
+        ClsSet.append(out[0]==class_to_idx[cls])
+
+        #handle resize and centercrop for gt_boxes
+        for j in range(len(gt_boxes)):
+            temp_list = list(gt_boxes[j])
+            raw_img_i, gt_bbox_i = ResizedBBoxCrop((256,256))(raw_img, temp_list)
+            raw_img_i, gt_bbox_i = CenterBBoxCrop((224))(raw_img_i, gt_bbox_i)
+            w, h = raw_img_i.size
+
+            gt_bbox_i[0] = gt_bbox_i[0] * w
+            gt_bbox_i[2] = gt_bbox_i[2] * w
+            gt_bbox_i[1] = gt_bbox_i[1] * h
+            gt_bbox_i[3] = gt_bbox_i[3] * h
+
+            gt_boxes[j] = gt_bbox_i
+
+        w, h = raw_img_i.size
+
+        bbox[0] = bbox[0] * w
+        bbox[2] = bbox[2] * w + bbox[0]
+        bbox[1] = bbox[1] * h
+        bbox[3] = bbox[3] * h + bbox[1]
+
+        max_iou = -1
+        for gt_bbox in gt_boxes:
+            iou = IoU(bbox, gt_bbox)
+            if iou > max_iou:
+                max_iou = iou
+
+        LocSet.append(max_iou)
+        temp_loc_iou = max_iou
+        if out[0] != class_to_idx[cls]:
+            max_iou = 0
+
+        # print(max_iou)
+        result[os.path.join(cls, name)] = max_iou
+        IoUSet.append(max_iou)
+        #cal top5 IoU
+        max_iou = 0
+        for i in range(5):
+            if out[i] == class_to_idx[cls]:
+                max_iou = temp_loc_iou
+        IoUSetTop5.append(max_iou)
+        #visualization code
+        '''
+        opencv_image = deepcopy(np.array(raw_img_i))
+        opencv_image = opencv_image[:, :, ::-1].copy()
+        for gt_bbox in gt_boxes:
+            cv2.rectangle(opencv_image, (int(gt_bbox[0]), int(gt_bbox[1])),
+                          (int(gt_bbox[2]), int(gt_bbox[3])), (0, 255, 0), 4)
+        cv2.rectangle(opencv_image, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
+                      (0, 255, 255), 4)
+        cv2.imwrite(os.path.join(savepath, str(name) + '.jpg'), np.asarray(opencv_image))
+        '''
+    cls_loc_acc = np.sum(np.array(IoUSet) > 0.5) / len(IoUSet)
+    final_clsloc.extend(IoUSet)
+    cls_loc_acc_top5 = np.sum(np.array(IoUSetTop5) > 0.5) / len(IoUSetTop5)
+    final_clsloctop5.extend(IoUSetTop5)
+    loc_acc = np.sum(np.array(LocSet) > 0.5) / len(LocSet)
+    final_loc.extend(LocSet)
+    cls_acc = np.sum(np.array(ClsSet))/len(ClsSet)
+    final_cls.extend(ClsSet)
+    print('{} cls-loc acc is {}, loc acc is {}, cls acc is {}'.format(cls, cls_loc_acc, loc_acc, cls_acc))
+    with open('inference_CorLoc.txt', 'a+') as corloc_f:
+        corloc_f.write('{} {}\n'.format(cls, loc_acc))
+    accs.append(cls_loc_acc)
+    accs_top5.append(cls_loc_acc_top5)
+    loc_accs.append(loc_acc)
+    cls_accs.append(cls_acc)
+    if (k+1) %100==0:
+        print(k)
+
+
+print(accs)
+print('Cls-Loc acc {}'.format(np.mean(accs)))
+print('Cls-Loc acc Top 5 {}'.format(np.mean(accs_top5)))
+
+print('GT Loc acc {}'.format(np.mean(loc_accs)))
+print('{} cls acc {}'.format(clsname, np.mean(cls_accs)))
+with open('origin_result.txt', 'w') as f:
+    for k in sorted(result.keys()):
+        f.write('{} {}\n'.format(k, str(result[k])))
\ No newline at end of file
diff --git a/utils/IoU.py b/utils/IoU.py
new file mode 100644
index 0000000..eced854
--- /dev/null
+++ b/utils/IoU.py
@@ -0,0 +1,106 @@
+import numpy as np
+import xml.etree.ElementTree as ET
+
+def get_gt_boxes(xmlfile):
+    '''get ground-truth bbox from VOC xml file'''
+    tree = ET.parse(xmlfile)
+    objs = tree.findall('object')
+    num_objs = len(objs)
+    gt_boxes = []
+    for obj in objs:
+        bbox = obj.find('bndbox')
+        x1 = float(bbox.find('xmin').text)-1
+        y1 = float(bbox.find('ymin').text)-1
+        x2 = float(bbox.find('xmax').text)-1
+        y2 = float(bbox.find('ymax').text)-1
+
+        gt_boxes.append((x1, y1, x2, y2))
+    return gt_boxes
+
+def get_cls_gt_boxes(xmlfile, cls):
+    '''get ground-truth bbox of a given class from VOC xml file'''
+    tree = ET.parse(xmlfile)
+    objs = tree.findall('object')
+    num_objs = len(objs)
+    gt_boxes = []
+    for obj in objs:
+        bbox = obj.find('bndbox')
+        cls_name = obj.find('name').text
+        #print(cls_name, cls)
+        if cls_name != cls:
+            continue
+        x1 = float(bbox.find('xmin').text)-1
+        y1 = float(bbox.find('ymin').text)-1
+        x2 = float(bbox.find('xmax').text)-1
+        y2 = float(bbox.find('ymax').text)-1
+
+        gt_boxes.append((x1, y1, x2, y2))
+    if len(gt_boxes)==0:
+        pass
+        #print('%s bbox = 0'%cls)
+
+    return gt_boxes
+
+def get_cls_and_gt_boxes(xmlfile, cls, class_to_idx):
+    '''get (class index, bbox in xywh format) pairs of a given class from VOC xml file'''
+    tree = ET.parse(xmlfile)
+    objs = tree.findall('object')
+    num_objs = len(objs)
+    gt_boxes = []
+    for obj in objs:
+        bbox = obj.find('bndbox')
+        cls_name = obj.find('name').text
+        #print(cls_name, cls)
+        if cls_name != cls:
+            continue
+        x1 = float(bbox.find('xmin').text)-1
+        y1 = float(bbox.find('ymin').text)-1
+        x2 = float(bbox.find('xmax').text)-1
+        y2 = float(bbox.find('ymax').text)-1
+
+        gt_boxes.append((class_to_idx[cls_name],[x1, y1, x2-x1, y2-y1]))
+    if len(gt_boxes)==0:
+        pass
+        #print('%s bbox = 0'%cls)
+
+    return gt_boxes
+def convert_boxes(boxes):
+    '''convert each bbox to the format (x1, y1, x2, y2) where x1 <= x2 and y1 <= y2'''
+    converted_boxes = []
+    for bbox in boxes:
+        (x1, y1, x2, y2) = bbox
+        converted_boxes.append((min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)))
+    return converted_boxes
+
+def IoU(a, b):
+    # a and b are boxes in (x1, y1, x2, y2) format
+    x1 = max(a[0], b[0])
+    y1 = max(a[1], b[1])
+    x2 = min(a[2], b[2])
+    y2 = min(a[3], b[3])
+    w = max(0.0, x2 - x1)
+    h = max(0.0, y2 - y1)
+    inter = w * h
+    aarea = (a[2] - a[0]) * (a[3] - a[1])
+    barea = (b[2] - b[0]) * (b[3] - b[1])
+    #assert aarea+barea-inter>0
+    if aarea + barea - inter <=0:
+        print(a)
+        print(b)
+    o = inter / (aarea+barea-inter)
+    #if w<=0 or h<=0:
+    #    o = 0
+    return o
\ No newline at end of file
diff --git a/utils/__pycache__/IoU.cpython-36.pyc b/utils/__pycache__/IoU.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d425085598a1015c4e7a1355ac02ea610ad46c9a
GIT binary patch
literal 2676
diff --git a/utils/__pycache__/vis.cpython-36.pyc b/utils/__pycache__/vis.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98e06fd744d378b0e8b2b6be0e6e78c3ff1af18c
GIT binary patch
literal 794
diff --git a/utils/nms.py b/utils/nms.py
new file mode 100644
--- /dev/null
+++ b/utils/nms.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+def nms(boxes, scores, thresh):
+    '''greedy NMS: keep the highest-scoring box and drop boxes that overlap it by more than thresh'''
+    boxes = np.array(boxes)
+    scores = np.array(scores)
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    # ascending sort by score; the highest-scoring box sits at the end
+    order = scores.argsort()
+    keep_boxes = []
+    while order.size > 0:
+        i = order[-1]
+        keep_boxes.append(boxes[i])
+
+        xx1 = np.maximum(x1[i], x1[order[:-1]])
+        yy1 = np.maximum(y1[i], y1[order[:-1]])
+        xx2 = np.minimum(x2[i], x2[order[:-1]])
+        yy2 = np.minimum(y2[i], y2[order[:-1]])
+
+        w = np.maximum(0.0, xx2-xx1+1)
+        h = np.maximum(0.0, yy2-yy1+1)
+        inter = w*h
+
+        ovr = inter / (areas[i] + areas[order[:-1]] - inter)
+        inds = np.where(ovr <= thresh)
+        order = order[inds]
+
+    return keep_boxes
diff --git a/utils/vis.py b/utils/vis.py
new file mode 100644
index 0000000..fb39ae9
--- /dev/null
+++ b/utils/vis.py
@@ -0,0 +1,22 @@
+from __future__ import print_function
+
+import cv2
+import numpy as np
+import os
+
+_GREEN = (18, 217, 15)
+_RED = (15, 18, 217)
+
+def vis_bbox(img, bbox, color=_GREEN, thick=1):
+    '''Visualize a bounding box'''
+    img = img.astype(np.uint8)
+    (x0, y0, x1, y1) = bbox
+    cv2.rectangle(img, (int(x0), int(y0)), (int(x1), int(y1)), color, thickness=thick)
+    return img
+
+def vis_one_image(img, boxes, color=_GREEN):
+    for bbox in boxes:
+        img = vis_bbox(img, (bbox[0], bbox[1], bbox[2], bbox[3]), color)
+    return img
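
For reference (not part of the patch itself): assuming the ImageNet validation images are laid out by class under ./val/, the matching VOC-style XML annotations under ./anno/val/, and the corresponding localization checkpoint (e.g. densenet161loc.pth.tar) sits in the working directory, the evaluation script added above would be invoked roughly as follows; the flag names come from the argparse setup in PSOL_inference.py, while the chosen model names and directory layout may differ in practice:

    python PSOL_inference.py --loc-model densenet161 --cls-model efficientnetb7 --ten-crop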