restructure semi-supervised code and remove minor run-time errors
Navaneet committed Oct 20, 2022
1 parent 737f40b commit ab07cfe
Showing 6 changed files with 47 additions and 18 deletions.
13 changes: 8 additions & 5 deletions data_loader.py
@@ -1,17 +1,16 @@
import os
import pdb

import random
import torch
from torchvision import transforms, datasets
from PIL import ImageFilter

-from .util import subset_classes
+from util import subset_classes


# Extended version of ImageFolder to return index of image too.
class ImageFolderEx(datasets.ImageFolder):
# def __init__(self, root, sup_split_file, only_sup, *args, **kwargs):
# super(ImageFolderEx, self).__init__(root, *args, **kwargs)
def __init__(self, root, transforms, sup_split_file=None, only_sup=False, corrupt_split_file=None):
super(ImageFolderEx, self).__init__(root, transforms)

@@ -145,15 +144,19 @@ def get_train_loader(opt):
# Applicable only for semi-supervised setup.
if 'sup_split_file' in vars(opt).keys():
# Get dataloader for pseudo-labelling
-sup_val_dataset = ImageFolderEx(traindir, opt.sup_split_file, True, transforms.Compose(augmentation_weak))
+sup_val_dataset = ImageFolderEx(
+    root=traindir,
+    sup_split_file=opt.sup_split_file,
+    only_sup=True,
+    transforms=transforms.Compose(augmentation_weak)
+)
if opt.dataset == 'imagenet100':
subset_classes(sup_val_dataset, num_classes=100)
train_val_loader = torch.utils.data.DataLoader(
sup_val_dataset,
batch_size=opt.batch_size, shuffle=False,
num_workers=opt.num_workers, pin_memory=True,
)

return train_loader, train_val_loader
else:
return train_loader
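A note on the ImageFolderEx change above: the old single-line call passed its arguments positionally, so opt.sup_split_file was bound to the transforms parameter of the new __init__ and True to sup_split_file; the dataset would then try to apply a path string as its transform, which is presumably one of the "minor run-time errors" this commit removes. The keyword form pins each value to the intended parameter. An annotated restatement of the new call (the comments are my reading of the parameter names, not taken from the repo):

sup_val_dataset = ImageFolderEx(
    root=traindir,                                     # ImageNet train directory
    sup_split_file=opt.sup_split_file,                 # file listing the labelled subset (e.g. 10percent.txt)
    only_sup=True,                                     # restrict the dataset to that labelled subset
    transforms=transforms.Compose(augmentation_weak),  # a callable, as ImageFolder expects
)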
7 changes: 4 additions & 3 deletions semi_supervised/run_fullprecision.sh → run_fullprecision.sh
@@ -6,8 +6,9 @@ set -e
base_dir='./'
exp_name='exp_semi_sup_fullprec_1'
dataset_path='path/to/imagenet/dataset'
+#dataset_path='/datasets/imagenet'

-CUDA_VISIBLE_DEVICES=0,1,2,3 python train_pseudo_cmsf.py \
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m semi_supervised.train_pseudo_cmsf \
--base-dir $base_dir \
--exp $exp_name\
--learning_rate 0.05\
@@ -29,10 +30,10 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python train_pseudo_cmsf.py \
--mem_bank_size 128000 \
--sup_mem_bank_size 128000 \
--save_freq 10\
---sup-split-file 'imagenet_subsets/1p_10p/subsets/10percent.txt' \
+--sup-split-file 'semi_supervised/imagenet_subsets/1p_10p/subsets/10percent.txt' \
$dataset_path

exp="$base_dir/semi_sup_cmsf/exp/$exp_name"
exp="$base_dir/exp/semi_sup_cmsf/$exp_name"
ep=200

CUDA_VISIBLE_DEVICES=0,1 python eval_linear.py\
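The invocation change above is the key to the restructuring: python -m semi_supervised.train_pseudo_cmsf runs the trainer as a module of the semi_supervised package while the working directory (the repository root) stays on sys.path, so the trainer's imports of the top-level data_loader, util, and tools modules resolve. The --sup-split-file path gains the semi_supervised/ prefix for the same reason, since it is now resolved from the repo root, and the eval step's exp directory is updated to match where the trainer apparently writes its outputs. A quick local check of the difference, if useful (a sketch; check_path.py is a hypothetical file placed inside semi_supervised/):

# semi_supervised/check_path.py -- hypothetical, for illustration only
import sys

# From the repo root, `python -m semi_supervised.check_path` keeps the working
# directory at the front of sys.path, so `import util` / `import data_loader` work.
# `python semi_supervised/check_path.py` puts semi_supervised/ there instead,
# and those top-level imports fail with ModuleNotFoundError.
print(sys.path[0])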
6 changes: 3 additions & 3 deletions semi_supervised/pseudo_cmsf.py
@@ -4,10 +4,10 @@
import torch.nn as nn
import torch.nn.functional as F

-import .models.resnet as resnet
-from .models.mlp_arch import get_mlp, get_mlp_3l
+import models.resnet as resnet
+from models.mlp_arch import get_mlp, get_mlp_3l

-from .util import get_shuffle_ids
+from util import get_shuffle_ids


class PseudoCMSF(nn.Module):
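The first removed import here was more than a style issue: "import .models.resnet as resnet" is not legal Python (relative imports can only be written with "from ..."), so the module failed as soon as it was loaded, which likely accounts for part of the "minor run-time errors" in the commit title. The replacement switches to absolute imports, which resolve when models/ sits at the repository root and the code is launched from there, as the new -m invocation does. For reference, the import spellings that do parse (a sketch; whether the relative forms resolve depends on where models/ actually lives):

# Valid relative-import syntax inside a package module:
#   from . import models
#   from .models import resnet
# `import .models.resnet as resnet` is a SyntaxError.
# The commit instead uses absolute imports rooted at the repo directory:
import models.resnet as resnet
from models.mlp_arch import get_mlp, get_mlp_3l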
3 changes: 2 additions & 1 deletion semi_supervised/pseudo_label_train.py
@@ -3,7 +3,8 @@
import torch
import torch.nn.functional as F

-from util import AverageMeterV2, ProgressMeter, accuracy
+from util import AverageMeterV2, accuracy
+from tools import ProgressMeter


def train_pseudo_lbl(pseudo_cmsf, train_val_loader, backbone, model, logger, opt):
10 changes: 5 additions & 5 deletions semi_supervised/train_pseudo_cmsf.py
@@ -13,11 +13,11 @@
import torch
import torch.backends.cudnn as cudnn

-from .data_loader import get_train_loader
-from pseudo_cmsf import PseudoCMSF
-from pseudo_label_train import train_pseudo_lbl
-from .util import adjust_learning_rate, AverageMeter
-from .tools import get_logger
+from data_loader import get_train_loader
+from semi_supervised.pseudo_cmsf import PseudoCMSF
+from semi_supervised.pseudo_label_train import train_pseudo_lbl
+from util import adjust_learning_rate, AverageMeter
+from tools import get_logger


def parse_option():
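A small design note on the imports above: once the script is started with python -m semi_supervised.train_pseudo_cmsf, the package-qualified imports are interchangeable with relative ones, and the absolute form chosen here has the minor advantage of also working when these modules are imported from code at the repo root. The equivalent relative spelling, for comparison (a hypothetical alternative, not what the commit uses):

from .pseudo_cmsf import PseudoCMSF
from .pseudo_label_train import train_pseudo_lbl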
26 changes: 25 additions & 1 deletion util.py
@@ -44,6 +44,30 @@ def update(self, val, n=1):
self.avg = self.sum / self.count


+class AverageMeterV2(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
@@ -56,7 +80,7 @@ def accuracy(output, target, topk=(1,)):

res = []
for k in topk:
-correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

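For completeness, a minimal usage sketch of the newly added AverageMeterV2 (assumes util.py is importable from the repo root; the numbers are made up). pseudo_label_train.py imports it from util and, per the change above, now takes ProgressMeter from tools, presumably to group such meters for per-iteration logging:

from util import AverageMeterV2

losses = AverageMeterV2('Loss', ':.4f')
for batch_loss, batch_size in [(0.9, 256), (0.7, 256), (0.5, 128)]:
    losses.update(batch_loss, n=batch_size)
print(losses)  # -> Loss 0.5000 (0.7400): last value, then the running (sample-weighted) average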
