Skip to content

Commit

Permalink
Fixed bug with filtering by phase. Added phase param to getImgs().
Browse files Browse the repository at this point in the history
  • Loading branch information
mricha56 committed Jul 19, 2017
1 parent 8e29a42 commit e7c2cf4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 25 deletions.
47 changes: 22 additions & 25 deletions PythonAPI/detail/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,36 +68,19 @@ def __init__(self, annotation_file='json/trainval_merged.json',

def __createIndex(self):
# create index
tic = time.time()
print('creating index...')

# create class members
self.cats,self.imgs,self.segmentations,self.occlusion,self.parts= {},{},{},{},{}

phases = []
if "train" in self.phase: phases.append("train")
if "val" in self.phase: phases.append("val")
if "test" in self.phase: phases.append("test")
assert len(phases) > 0, 'Invalid phase, {}'.format(self.phase)

# Filter images and annotations according to phase
# Organize data into instance variables
for img in self.data['images']:
if img['phase'] not in phases:
self.data['images'].remove(img)
else:
self.imgs[img['image_id']] = img

imgIds = list(self.imgs.keys())
self.imgs[img['image_id']] = img
for segm in self.data['annos_segmentation']:
if segm['image_id'] not in imgIds:
self.data['annos_segmentation'].remove(segm)
else:
self.segmentations[segm['id']] = segm

self.segmentations[segm['id']] = segm
for occl in self.data['annos_occlusion']:
if occl['image_id'] not in imgIds:
self.data['annos_occlusion'].remove(occl)
else:
self.occlusion[occl['image_id']] = occl
self.occlusion[occl['image_id']] = occl

# Follow references
for img in self.data['images']:
Expand Down Expand Up @@ -147,7 +130,7 @@ def __createIndex(self):
img = self.imgs[occl['image_id']]
img['annotations'].append(occl_id)

print('index created!')
print('index created! {:0.2f}s'.format(time.time() - tic))

def info(self):
"""
Expand Down Expand Up @@ -490,15 +473,24 @@ def getParts(self, parts=[], cat=None, superpart=None):

return parts

def getImgs(self, imgs=[], cats=[], supercat=None):
def getImgs(self, imgs=[], cats=[], supercat=None, phase=None):
'''
Get images that satisfy given filter conditions.
:param imgs (int/string/dict array) : get imgs with given ids
:param cats (int/string/dict array) : get imgs with all given cats
:param supercat (string) : get imgs with the given supercategory
:param phase (string) : filter images by phase. If None, the phase
provided to the Detail() constructor is used.
:return: images (dict array) : array of image dicts
'''
if phase is None:
phase = self.phase
phases = []
if "train" in phase: phases.append("train")
if "val" in phase: phases.append("val")
if "test" in phase: phases.append("test")
assert len(phases) > 0, 'Invalid phase, {}'.format(phase)

imgs = self.__toList(imgs)
if len(imgs) == 0:
imgs = list(self.imgs.values())
Expand Down Expand Up @@ -526,6 +518,11 @@ def getImgs(self, imgs=[], cats=[], supercat=None):
if len(catIds & set(img['categories'])) == 0:
imgs.remove(img)

oldimgs = imgs.copy()
for img in oldimgs:
if img['phase'] not in phases:
imgs.remove(img)

return imgs

def decodeMask(self, json):
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Python API for the [PASCAL in Detail](https://sites.google.com/view/pasd/dataset
## To install:
Run `make` and `make install` under the [PythonAPI](PythonAPI/) directory.

In Python:

```python
from detail import Detail
details = Detail('json/trainval_merged.json', 'VOCdevkit/VOC2010/JPEGImages')
```

If you wish to use the API from MATLAB, see [MATLAB's documentation for calling Python code](https://www.mathworks.com/help/matlab/matlab_external/call-python-from-matlab.html). The Detail API no longer maintains a separate MATLAB API.

## To see a demo:
Expand Down
2 changes: 2 additions & 0 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,5 @@ def printProgress(count, blockSize, totalSize):
urlretrieve(url, filepath + '.download', reporthook=printProgress)
os.rename(filepath + '.download', filepath)
print("Download complete!")
else:
print('Don\'t recognize dataset %s' % sys.argv[1].lower())

0 comments on commit e7c2cf4

Please sign in to comment.