
add support for imagenet
ruthcfong committed Nov 4, 2017
1 parent 6f441f2 commit a759612
Showing 1 changed file with 55 additions and 37 deletions.
92 changes: 55 additions & 37 deletions src/linearprobe_disc_pytorch.py
@@ -182,25 +182,42 @@ def linear_probe_discriminative(directory, blob, label_i, suffix='', batch_size=
     # print np.max(thresh), thresh.shape, type(thresh)
     # Map the blob activation data for reading
     fn_read = ed.mmap_filename(blob=blob)
-    # Load the dataset
-    ds = loadseg.SegmentationData(info.dataset)
-    # Get all the categories the label is a part of
-    label_categories = ds.label[label_i]['category'].keys()
-    num_categories = len(label_categories)
-    # Get label name
-    label_name = ds.name(category=None, j=label_i)
-
-    blobdata = cached_memmap(fn_read, mode='r', dtype='float32', shape=shape)
-    # Get indices of images containing the given label
-    if not has_image_to_label(directory):
-        print('image_to_label does not exist in %s; creating it now...' % directory)
-        create_image_to_label(directory, batch_size=batch_size)
-    image_to_label = load_image_to_label(directory)
-    label_idx = np.where(image_to_label[:, label_i])[0]
-    non_label_idx = np.where(image_to_label[:, label_i] == 0)[0]
+    if 'broden' in info.dataset:
+        # Load the dataset
+        ds = loadseg.SegmentationData(info.dataset)
+        # Get all the categories the label is a part of
+        label_categories = ds.label[label_i]['category'].keys()
+        num_categories = len(label_categories)
+        # Get label name
+        label_name = ds.name(category=None, j=label_i)
+
+        # Get indices of images containing the given label
+        if not has_image_to_label(directory):
+            print('image_to_label does not exist in %s; creating it now...' % directory)
+            create_image_to_label(directory, batch_size=batch_size)
+        image_to_label = load_image_to_label(directory)
+        label_idx = np.where(image_to_label[:, label_i])[0]
+        non_label_idx = np.where(image_to_label[:, label_i] == 0)[0]
+        print('Number of positive and negative examples of label %d (%s): %d %d' % (
+            label_i, label_name, len(label_idx), len(non_label_idx)))
+    elif 'imagenet' in info.dataset or 'ILSVRC' in info.dataset:
+        # TODO: don't hardcode
+        label_desc = np.loadtxt('/users/ruthfong/packages/caffe/data/ilsvrc12/synset_words.txt', str, delimiter='\t')
+        label_name = ' '.join(label_desc[label_i].split(',')[0].split()[1:])
+        image_to_label_train = load_image_to_label(os.path.join(directory, 'train'))
+        train_label_idx = np.where(image_to_label_train[:, label_i])[0]
+        train_non_label_idx = np.where(image_to_label_train[:, label_i] == 0)[0]
+        image_to_label_val = load_image_to_label(os.path.join(directory, 'val'))
+        val_label_idx = np.where(image_to_label_val[:, label_i])[0]
+        val_non_label_idx = np.where(image_to_label_val[:, label_i] == 0)[0]
+        print('Number of positive and negative examples of label %d (%s): %d %d %d %d' %
+              (label_i, label_name, len(train_label_idx), len(train_non_label_idx),
+               len(val_label_idx), len(val_non_label_idx)))
+    else:
+        assert(False)

-    print('Number of positive and negative examples of label %d (%s): %d %d' % (
-        label_i, label_name, len(label_idx), len(non_label_idx)))
+    blobdata = cached_memmap(fn_read, mode='r', dtype='float32', shape=shape)
 
     criterion = torch.nn.BCEWithLogitsLoss()
     if num_filters is not None:
@@ -225,26 +242,27 @@ def linear_probe_discriminative(directory, blob, label_i, suffix='', batch_size=
     optimizer = Custom_Adam(layer.parameters(), lr, l1_weight_decay=l1_weight_decay,
             l2_weight_decay=l2_weight_decay, lower_bound=lower_bound)
 
-    train_label_idx = []
-    val_label_idx = []
-    for ind in label_idx:
-        if ds.split(ind) == 'train':
-            train_label_idx.append(ind)
-        elif ds.split(ind) == 'val':
-            val_label_idx.append(ind)
-
-    train_non_label_idx = []
-    val_non_label_idx = []
-    for ind in non_label_idx:
-        if ds.split(ind) == 'train':
-            train_non_label_idx.append(ind)
-        elif ds.split(ind) == 'val':
-            val_non_label_idx.append(ind)
-
-    train_label_idx = np.array(train_label_idx)
-    val_label_idx = np.array(val_label_idx)
-    train_non_label_idx = np.array(train_non_label_idx)
-    val_non_label_idx = np.array(val_non_label_idx)
+    if 'broden' in info.dataset:
+        train_label_idx = []
+        val_label_idx = []
+        for ind in label_idx:
+            if ds.split(ind) == 'train':
+                train_label_idx.append(ind)
+            elif ds.split(ind) == 'val':
+                val_label_idx.append(ind)
+
+        train_non_label_idx = []
+        val_non_label_idx = []
+        for ind in non_label_idx:
+            if ds.split(ind) == 'train':
+                train_non_label_idx.append(ind)
+            elif ds.split(ind) == 'val':
+                val_non_label_idx.append(ind)
+
+        train_label_idx = np.array(train_label_idx)
+        val_label_idx = np.array(val_label_idx)
+        train_non_label_idx = np.array(train_non_label_idx)
+        val_non_label_idx = np.array(val_non_label_idx)

     num_train_labels = len(train_label_idx)
     num_train_non_labels = len(train_non_label_idx)
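
For reference, the new ImageNet branch derives a human-readable label name from the ILSVRC-2012 synset_words.txt file, whose lines look like "n01440764 tench, Tinca tinca". Below is a minimal sketch of that parsing, assuming the standard synset_words.txt format and a hypothetical local path (the commit itself hardcodes a cluster-specific location, as the TODO notes):

import numpy as np

# Hypothetical path; adjust to wherever synset_words.txt lives.
synset_words_path = 'synset_words.txt'

# Each row is a single string such as 'n01440764 tench, Tinca tinca';
# the tab delimiter keeps the whole line as one field.
label_desc = np.loadtxt(synset_words_path, str, delimiter='\t')

label_i = 0
# Drop everything after the first comma, then drop the leading WordNet id.
label_name = ' '.join(label_desc[label_i].split(',')[0].split()[1:])
print(label_name)  # -> 'tench'

Keeping only the first comma-separated synonym yields short names ('tench' rather than 'tench, Tinca tinca').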
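Both branches select positive and negative training examples from an image-to-label indicator matrix. A small self-contained sketch of that masking, assuming image_to_label is a (num_images x num_labels) 0/1 array like the one returned by load_image_to_label:

import numpy as np

# Toy indicator matrix: 5 images x 3 labels (1 = image contains the label).
image_to_label = np.array([[1, 0, 0],
                           [0, 1, 0],
                           [1, 0, 1],
                           [0, 0, 0],
                           [1, 1, 0]])

label_i = 0
label_idx = np.where(image_to_label[:, label_i])[0]           # positives: images 0, 2, 4
non_label_idx = np.where(image_to_label[:, label_i] == 0)[0]  # negatives: images 1, 3
print(label_idx, non_label_idx)

The == 0 comparison is what distinguishes the negatives; without it, both np.where calls return the same positive indices.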
