Skip to content

Commit

Permalink
Fix CH CUB DS for NTS-Net, 2
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed May 23, 2019
1 parent a262c49 commit f167a94
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion chainer_/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_val_data_source(ds_metainfo,
repeat=False,
shuffle=False,
n_processes=num_workers,
shared_mem=300000000)
shared_mem=1000000)
return {
# "transform": transform,
"iterator": iterator,
Expand Down
33 changes: 21 additions & 12 deletions chainer_/datasets/cub200_2011_cls_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import os
import numpy as np
import pandas as pd
from chainer.dataset import DatasetMixin
from chainercv.chainer_experimental.datasets.sliceable import GetterDataset
from chainercv.utils import read_image
from .imagenet1k_cls_dataset import ImageNet1KMetaInfo


class CUB200_2011(DatasetMixin):
class CUB200_2011(GetterDataset):
"""
CUB-200-2011 fine-grained classification dataset.
Expand Down Expand Up @@ -82,23 +82,32 @@ def __init__(self,

self._transform = transform

def __getitem__(self, index):
image_file_name = self.image_file_names[index]
image_file_path = os.path.join(self.images_dir_path, image_file_name)
img = read_image(image_file_path, color=True)
label = int(self.class_ids[index])
self.add_getter('img', self._get_image)
self.add_getter('label', self._get_label)

def _get_image(self, i):
image_file_name = self.image_file_names[i]
image_file_path = os.path.join(self.images_dir_path, image_file_name)
image = read_image(image_file_path, color=True)
if self._transform is not None:
img = self._transform(img)
image = self._transform(image)
return image

return img, label
def _get_label(self, i):
label = int(self.class_ids[i])
return label

def __len__(self):
return len(self.image_ids)

def get_example(self, i):
image, label = self[i]
return image, label
# def __getitem__(self, i):
# image = self._get_image(i)
# label = self._get_label(i)
# return image, label
#
# def get_example(self, i):
# image, label = self[i]
# return image, label


class CUB200MetaInfo(ImageNet1KMetaInfo):
Expand Down

0 comments on commit f167a94

Please sign in to comment.