Skip to content

Commit

Permalink
update code for NIPS submission
Browse files Browse the repository at this point in the history
  • Loading branch information
ynahshan committed May 24, 2019
1 parent e853c3a commit b92f939
Show file tree
Hide file tree
Showing 6 changed files with 429 additions and 186 deletions.
119 changes: 92 additions & 27 deletions inference/inference_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,27 @@
from utils.model_naming import set_node_names
import numpy as np
from utils.dump_manager import DumpManager as DM
import pretrainedmodels
import pretrainedmodels.utils as mutils
from pathlib import Path

import mlflow


# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True


# Current user's home directory; used below to build the MLflow tracking path.
home = str(Path.home())
# Default ImageNet root; it is the default value of the --data argument.
IMAGENET_FOR_INFERENCE = '/home/cvds_lab/datasets/ILSVRC2012/'

# Keep MLflow run data in the 'mlruns_mxt' directory under the home directory.
mlflow.set_tracking_uri(os.path.join(home, 'mlruns_mxt'))

# Every lowercase, non-dunder callable exposed by `models` counts as a
# selectable architecture name.
model_names = sorted(
    key for key, obj in models.__dict__.items()
    if key.islower() and not key.startswith("__") and callable(obj)
)
# Two extra architecture names that are not taken from `models`, followed by
# everything the `pretrainedmodels` package offers.
model_names.extend(['shufflenet', 'mobilenetv2'])
model_names.extend(pretrainedmodels.model_names)

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='DIR', default=IMAGENET_FOR_INFERENCE,
Expand Down Expand Up @@ -75,17 +86,27 @@
parser.add_argument('--rho_act', '-ra', default=None, type=float, help='Rho parameter for activations clipping')
parser.add_argument('--rho_weight', '-rw', default=None, type=float, help='Rho parameter for weights clipping')
parser.add_argument('--stats_mode', '-sm', default='no', help='Specify if collect stats, use or not stats: [collect, use, no]')
parser.add_argument('--stats_kind', '-sk', default='avg', help='Specify kind of stats to use: [avg, max]')
parser.add_argument('--stats_kind', '-sk', default='mean', help='Specify kind of stats to use: [mean, max]')
# --- Statistics, debugging and dump options ---
parser.add_argument('--stats_folder', '-sf', default=None, help='Specify directory for statistics')
parser.add_argument('--stats_batch_avg', '-sba', action='store_true', help='Whether average statistics across the batch')
parser.add_argument('--custom_test', '-ct', action='store_true', default=False, help='Perform some custom test.')
parser.add_argument('--dump_dir', '-dd', default=None, help='Directory to dump tensors')
parser.add_argument('--measure_stats', '-m', action='store_true', help='Measure statistics of activations during runtime', default=False)
parser.add_argument('--measure_stats_folder', '-mf', help='Folder to save measured statistics of activations during runtime', default=None)

# --- Calibration options ---
# NOTE(review): this help text used to be a verbatim copy of the --measure_stats
# help; the wording below is inferred from the flag name and its use next to
# --aciq_cal when limiting the calibration set -- confirm with the author.
parser.add_argument('--kld_threshold', '-kld', action='store_true', help='Enable KLD threshold calibration mode', default=False)
parser.add_argument('--aciq_cal', '-ac', action='store_true', help='Enable aciq calibration mode', default=False)
# The help text now reports the real default (5120) instead of the stale 2048.
parser.add_argument('--cal_set_size', '-cs', default=5120, type=int, help='Size of calibration set for threshold evaluation (default: 5120)')
parser.add_argument('--subset', '-ss', default=None, type=int, help='Run on subset of data')

# --- Per-channel quantization and bit allocation ---
parser.add_argument('--per_channel_quant_weights', '-pcq_w', action='store_true', help='Per channel quantization of weights', default=False)
parser.add_argument('--per_channel_quant_act', '-pcq_a', action='store_true', help='Per channel quantization of activations', default=False)
parser.add_argument('--bit_alloc_act', '-baa', action='store_true', help='Optimal bit allocation for each channel of activations', default=False)
parser.add_argument('--bit_alloc_weight', '-baw', action='store_true', help='Optimal bit allocation for each channel of weights', default=False)
parser.add_argument('--bit_alloc_rmode', '-bam', help='One of [round, ceil]', default='ceil')
parser.add_argument('--bit_alloc_prior', '-bap', help='One of [gaus, laplace]', default='gaus')

# --- Bias / variance correction ---
parser.add_argument('--bias_corr_act', '-bca', action='store_true', help='Bias correction for activations', default=False)
parser.add_argument('--bias_corr_weight', '-bcw', action='store_true', help='Bias correction for weights', default=False)
parser.add_argument('--var_corr_weight', '-vcw', action='store_true', help='Variance correction for weights', default=False)

# --- Experiment tracking ---
# When left as None, the __main__ section uses the architecture name as the
# MLflow experiment name.
parser.add_argument('--mlf_experiment', '-mlexp', help='Name of experiment', default=None)

# Parsed once at import time; the rest of the module reads the global `args`.
args = parser.parse_args()

if args.arch == 'resnet50':
Expand Down Expand Up @@ -125,7 +146,22 @@ def __init__(self):

# create model
print("=> using pre-trained model '{}'".format(args.arch))
self.model = models.__dict__[args.arch](pretrained=True)
if args.arch == 'shufflenet':
import models.ShuffleNet as shufflenet
self.model = shufflenet.ShuffleNet(groups=8)
params = torch.load('ShuffleNet_1g8_Top1_67.408_Top5_87.258.pth.tar')
self.model = torch.nn.DataParallel(self.model, args.device_ids)
self.model.load_state_dict(params)
elif args.arch == 'mobilenetv2':
from models.MobileNetV2 import MobileNetV2 as mobilenetv2
self.model = mobilenetv2()
params = torch.load('mobilenetv2_Top1_71.806_Top2_90.410.pth.tar')
self.model = torch.nn.DataParallel(self.model, args.device_ids)
self.model.load_state_dict(params)
elif args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
self.model = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
else:
self.model = models.__dict__[args.arch](pretrained=True)

set_node_names(self.model)

Expand All @@ -139,16 +175,17 @@ def __init__(self):
search_absorbe_bn(self.model)
QM().bn_folding = True

if args.qmodel is not None:
model_q_path = os.path.join(os.path.join(home, 'mxt-sim/models'), args.arch + '_kmeans%dbit%s.pt' % (args.qmodel, ('' if args.no_bias_corr else '_bcorr')))
model_q = torch.load(model_q_path)
qldict = set_node_names(model_q, create_ldict=True)
QM().ql_dict = qldict
model_q.to(args.device)
self.model.load_state_dict(model_q.state_dict())
del model_q
# if args.qmodel is not None:
# model_q_path = os.path.join(os.path.join(home, 'mxt-sim/models'), args.arch + '_lowp_pcq%dbit%s.pt' % (args.qmodel, ('' if args.no_bias_corr else '_bcorr')))
# model_q = torch.load(model_q_path)
# qldict = set_node_names(model_q, create_ldict=True)
# QM().ql_dict = qldict
# model_q.to(args.device)
# self.model.load_state_dict(model_q.state_dict())
# del model_q

self.model.to(args.device)
QM().quantize_model(self.model)

if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet' and args.arch != 'mobilenetv2':
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
Expand All @@ -165,20 +202,24 @@ def __init__(self):
# Data loading code
valdir = os.path.join(args.data, 'val')

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
resize = 256 if args.arch != 'inception_v3' else 299
crop_size = 224 if args.arch != 'inception_v3' else 299
tfs = [
transforms.Resize(resize),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
if args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
dataparallel = args.device_ids is not None and len(args.device_ids) > 1
tfs = [mutils.TransformImage(self.model.module if dataparallel else self.model)]
else:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
resize = 256 if args.arch != 'inception_v3' else 299
crop_size = 224 if args.arch != 'inception_v3' else 299
tfs = [
transforms.Resize(resize),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]

self.val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose(tfs)),
batch_size=args.batch_size, shuffle=(True if args.kld_threshold or args.aciq_cal else False or args.shuffle),
batch_size=args.batch_size, shuffle=(True if (args.kld_threshold or args.aciq_cal or args.shuffle) else False),
num_workers=args.workers, pin_memory=True)

def run(self):
Expand Down Expand Up @@ -220,6 +261,10 @@ def run(self):
print(elog)
else:
val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
if mlflow.active_run() is not None:
mlflow.log_metric('top1', val_prec1)
mlflow.log_metric('top5', val_prec5)
mlflow.log_metric('loss', val_loss)
return val_loss, val_prec1, val_prec5


Expand All @@ -240,13 +285,15 @@ def validate(val_loader, model, criterion):
with torch.no_grad():
end = time.time()
for i, (input, target) in enumerate(val_loader):
if args.stats_mode == 'collect' and i*args.batch_size >= args.cal_set_size and (args.kld_threshold or args.aciq_cal):
if (args.stats_mode == 'collect' and i*args.batch_size >= args.cal_set_size and (args.kld_threshold or args.aciq_cal)) or \
(args.subset is not None and i*args.batch_size >= args.subset):
break
# Uncomment to enable dump
# QM().disable()
# if i > 0:
# break

if i == 0:
QM().verbose = True
input = input.to(args.device)
target = target.to(args.device)
if args.dump_dir is not None and i == 5:
Expand All @@ -259,6 +306,7 @@ def validate(val_loader, model, criterion):
output = model(input)

QM().reset_counters()
QM().verbose = False

loss = criterion(output, target)

Expand Down Expand Up @@ -294,7 +342,14 @@ def get_params():
'true_zero': args.preserve_zero,
'kld': args.kld_threshold,
'pcq_weights': args.per_channel_quant_weights,
'pcq_act': args.per_channel_quant_act
'pcq_act': args.per_channel_quant_act,
'bit_alloc_act': args.bit_alloc_act,
'bit_alloc_weight': args.bit_alloc_weight,
'bit_alloc_rmode': args.bit_alloc_rmode,
'bit_alloc_prior': args.bit_alloc_prior,
'bcorr_act': args.bias_corr_act,
'bcorr_weight': args.bias_corr_weight,
'vcorr_weight': args.var_corr_weight
},
'qmanager':{
'rho_act': args.rho_act,
Expand All @@ -304,6 +359,16 @@ def get_params():
return qparams

if __name__ == '__main__':
with QM(args, get_params()):
im = InferenceModel()
im.run()
# Runs in 'collect' stats mode execute without MLflow tracking; every other
# mode is wrapped in an MLflow run that records all command-line arguments.
if args.stats_mode == 'collect':
    with QM(args, get_params()):
        im = InferenceModel()
        im.run()
else:
    # Fall back to the architecture name when no experiment name was given.
    experiment = args.mlf_experiment if args.mlf_experiment is not None else args.arch
    mlflow.set_experiment(experiment)
    run_label = "{}_W{}_A{}".format(args.arch, args.qweight, args.qtype)
    with mlflow.start_run(run_name=run_label):
        params = vars(args)
        for key, value in params.items():
            mlflow.log_param(key, value)
        with QM(args, get_params()):
            im = InferenceModel()
            im.run()
Loading

0 comments on commit b92f939

Please sign in to comment.