Skip to content

Commit

Permalink
update code with VLC
Browse files Browse the repository at this point in the history
  • Loading branch information
ynahshan committed Jul 31, 2019
1 parent 0a7dd3f commit c4a930d
Show file tree
Hide file tree
Showing 20 changed files with 311 additions and 119 deletions.
1 change: 0 additions & 1 deletion inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

49 changes: 30 additions & 19 deletions inference/inference_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
# import pretrainedmodels
# import pretrainedmodels.utils as mutils
from pathlib import Path
import mlflow

from utils.mllog import MLlogger


torch.backends.cudnn.deterministic = True
Expand All @@ -41,8 +42,6 @@
home = str(Path.home())
IMAGENET_FOR_INFERENCE = '/home/cvds_lab/datasets/ILSVRC2012/'

mlflow.set_tracking_uri(os.path.join(home, 'mlruns_mxt'))

model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
Expand Down Expand Up @@ -100,11 +99,14 @@
parser.add_argument('--per_channel_quant_act', '-pcq_a', action='store_true', help='Per channel quantization of activations', default=False)
parser.add_argument('--bit_alloc_act', '-baa', action='store_true', help='Optimal bit allocation for each channel of activations', default=False)
parser.add_argument('--bit_alloc_weight', '-baw', action='store_true', help='Optimal bit allocation for each channel of weights', default=False)
parser.add_argument('--bit_alloc_rmode', '-bam', help='One of [round, ceil]', default='ceil')
parser.add_argument('--bit_alloc_rmode', '-bam', help='One of [round, ceil]', default='round')
parser.add_argument('--bit_alloc_prior', '-bap', help='One of [gaus, laplace]', default='gaus')
parser.add_argument('--bit_alloc_target_act', '-bata', type=float, help='Target value for bit allocation quota of activations', default=None)
parser.add_argument('--bit_alloc_target_weight', '-batw', type=float, help='Target value for bit allocation quota of weights', default=None)
parser.add_argument('--bias_corr_act', '-bca', action='store_true', help='Bias correction for activations', default=False)
parser.add_argument('--bias_corr_weight', '-bcw', action='store_true', help='Bias correction for weights', default=False)
parser.add_argument('--var_corr_weight', '-vcw', action='store_true', help='Variance correction for weights', default=False)
parser.add_argument('--measure_entropy', '-me', action='store_true', help='Measure entropy of activations', default=False)
parser.add_argument('--mlf_experiment', '-mlexp', help='Name of experiment', default=None)
args = parser.parse_args()

Expand All @@ -121,8 +123,12 @@
elif args.arch == 'inception_v3':
max_mse_order_id = ['conv5_activation', 'conv12_activation', 'conv1_activation', 'conv7_activation', 'conv4_activation', 'conv2_activation', 'conv14_activation', 'conv19_activation', 'conv10_activation', 'conv92_activation', 'conv21_activation', 'conv22_activation', 'conv9_activation', 'conv77_activation', 'conv16_activation', 'conv47_activation', 'conv48_activation', 'conv17_activation', 'conv58_activation', 'conv8_activation', 'conv55_activation', 'conv56_activation', 'conv40_activation', 'conv63_activation', 'conv15_activation', 'conv62_activation', 'conv84_activation', 'conv54_activation', 'conv57_activation', 'conv52_activation', 'conv65_activation', 'conv91_activation', 'conv76_activation', 'conv34_activation', 'conv51_activation', 'conv85_activation', 'conv53_activation', 'conv83_activation', 'conv35_activation', 'conv50_activation', 'conv46_activation', 'conv82_activation', 'conv61_activation', 'conv30_activation', 'conv37_activation', 'conv67_activation', 'conv75_activation', 'conv64_activation', 'conv29_activation', 'conv66_activation', 'conv44_activation', 'conv33_activation', 'conv43_activation', 'conv38_activation', 'conv45_activation', 'conv42_activation', 'conv23_activation', 'conv36_activation', 'conv60_activation', 'conv32_activation', 'conv41_activation', 'conv79_activation', 'conv6_activation', 'conv13_activation', 'conv78_activation', 'conv20_activation', 'conv73_activation', 'conv74_activation', 'conv80_activation', 'conv31_activation', 'conv27_activation', 'conv81_activation', 'conv88_activation', 'conv68_activation', 'conv28_activation', 'conv26_activation', 'conv89_activation', 'conv72_activation', 'conv93_activation', 'conv90_activation', 'conv94_activation', 'conv3_activation', 'conv24_activation', 'conv87_activation', 'conv18_activation', 'conv69_activation', 'conv59_activation', 'conv25_activation', 'conv49_activation', 'linear1_activation', 'conv39_activation', 'conv86_activation', 'conv11_activation', 'conv95_activation']

torch.manual_seed(12345)


class InferenceModel:
def __init__(self):
def __init__(self, ml_logger=None):
self.ml_logger = ml_logger
global args, best_prec1

if args.seed is not None:
Expand Down Expand Up @@ -260,12 +266,12 @@ def run(self):
print(elog)
else:
val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
if mlflow.active_run() is not None:
mlflow.log_metric('top1', val_prec1)
mlflow.log_metric('top5', val_prec5)
mlflow.log_metric('loss', val_loss)
return val_loss, val_prec1, val_prec5
if self.ml_logger is not None and self.ml_logger.mlflow.active_run() is not None:
self.ml_logger.mlflow.log_metric('top1', val_prec1)
self.ml_logger.mlflow.log_metric('top5', val_prec5)
self.ml_logger.mlflow.log_metric('loss', val_loss)

return val_loss, val_prec1, val_prec5


def validate(val_loader, model, criterion):
Expand All @@ -287,6 +293,8 @@ def validate(val_loader, model, criterion):
if (args.stats_mode == 'collect' and i*args.batch_size >= args.cal_set_size and (args.kld_threshold or args.aciq_cal)) or \
(args.subset is not None and i*args.batch_size >= args.subset):
break
if args.measure_entropy and i*args.batch_size >= args.subset:
break
# Uncomment to enable dump
# QM().disable()
# if i > 0:
Expand Down Expand Up @@ -333,7 +341,7 @@ def validate(val_loader, model, criterion):

return losses.avg, top1.avg, top5.avg

def get_params():
def get_params(logger=None):
qparams = {
'int': {
'clipping': args.clipping,
Expand All @@ -346,9 +354,13 @@ def get_params():
'bit_alloc_weight': args.bit_alloc_weight,
'bit_alloc_rmode': args.bit_alloc_rmode,
'bit_alloc_prior': args.bit_alloc_prior,
'bit_alloc_target_act': args.bit_alloc_target_act,
'bit_alloc_target_weight': args.bit_alloc_target_weight,
'bcorr_act': args.bias_corr_act,
'bcorr_weight': args.bias_corr_weight,
'vcorr_weight': args.var_corr_weight
'vcorr_weight': args.var_corr_weight,
'logger': logger,
'measure_entropy': args.measure_entropy
},
'qmanager':{
'rho_act': args.rho_act,
Expand All @@ -357,15 +369,14 @@ def get_params():
} # TODO: add params for bfloat
return qparams


if __name__ == '__main__':
if args.stats_mode != 'collect':
mlflow.set_experiment(args.arch if args.mlf_experiment is None else args.mlf_experiment)
with mlflow.start_run(run_name="{}_W{}_A{}".format(args.arch, args.qweight, args.qtype)):
params = vars(args)
for p in params:
mlflow.log_param(p, params[p])
with QM(args, get_params()):
im = InferenceModel()
experiment = args.arch if args.mlf_experiment is None else args.mlf_experiment
with MLlogger(os.path.join(home, 'mlruns_mxt'), experiment, args,
name_args=[args.arch, "W{}A{}".format(args.qweight, args.qtype)]) as ml_logger:
with QM(args, get_params(ml_logger)):
im = InferenceModel(ml_logger)
im.run()
else:
with QM(args, get_params()):
Expand Down
1 change: 0 additions & 1 deletion pytorch_quantizer/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torchvision
from pytorch_quantizer.quantization import qtypes
from utils.misc import Singleton
from utils import attacher
Expand All @@ -17,7 +18,6 @@
from pytorch_quantizer.clipping.clipping_manager import StatisticalClipper, RatioClipper
from pytorch_quantizer.quantization.qtypes.dummy_quantizer import DummyQuantizer


VERBOSE = True

class StatsMode(Enum):
Expand Down Expand Up @@ -66,10 +66,10 @@ def forward(self, input):
QMI().stats_manager.save_tensor_stats(out, 'activation_pooling', out_id)
elif QMI().stats_mode is StatsMode.use_stats:
# Quantize using statistics
out = QMI().quantize_instant(out, "activation_pooling", stat_id=out_id, verbose=QMI().verbose)
out = QMI().quantize_instant(out, out_id, "activation_pooling", stat_id=out_id, verbose=QMI().verbose)
else:
# No stats, quantize using actual values
out = QMI().quantize_instant(out, "activation_pooling", verbose=QMI().verbose)
out = QMI().quantize_instant(out, out_id, "activation_pooling", verbose=QMI().verbose)

return out

Expand Down Expand Up @@ -174,7 +174,7 @@ def forward(self, input):
QMI().stats_manager.save_tensor_stats(out, self.internal_name, activation_id)
elif QMI().stats_mode is StatsMode.use_stats:
# Quantize using statistics
out_q = QMI().quantize_instant(out, tag_act, stat_id=activation_id,
out_q = QMI().quantize_instant(out, activation_id, tag_act, stat_id=activation_id,
half_range=hasattr(self, 'before_relu'), verbose=QMI().verbose)
# print("%s: %d" % (activation_id, out.shape[2]*out.shape[3]))
if QMI().bcorr_act:
Expand Down Expand Up @@ -204,7 +204,7 @@ def forward(self, input):

else:
# No stats, quantize using actual values
out = QMI().quantize_instant(out, tag_act, half_range=hasattr(self, 'before_relu'), verbose=QMI().verbose)
out = QMI().quantize_instant(out, activation_id, tag_act, half_range=hasattr(self, 'before_relu'), verbose=QMI().verbose)

if QMI().measure_stats.enabled:
QMI().measure_stats.save_measure(out, activation_id)
Expand Down Expand Up @@ -236,13 +236,13 @@ def forward(self, input):
if QMI().stats_mode is StatsMode.collect_stats:
QMI().stats_manager.save_tensor_stats(out, tag_act, activation_id, force_global_min_max=('classifier' in tag_act))
elif QMI().stats_mode is StatsMode.use_stats:
out_q = QMI().quantize_instant(out, tag_act, stat_id=activation_id, half_range=half_range,
out_q = QMI().quantize_instant(out, activation_id, tag_act, stat_id=activation_id, half_range=half_range,
verbose=QMI().verbose)

out = out_q

else:
out = QMI().quantize_instant(out, tag_act, half_range=half_range, verbose=QMI().verbose)
out = QMI().quantize_instant(out, activation_id, tag_act, half_range=half_range, verbose=QMI().verbose)

if QMI().measure_stats.enabled:
QMI().measure_stats.save_measure(out, activation_id)
Expand Down Expand Up @@ -339,8 +339,8 @@ def createTruncationManager(self, args, qparams):

return op_manager

def quantize_instant(self, tensor, tag="", stat_id=None, half_range=False, override_att=None, verbose=False):
return self.op_manager.quantize_instant(tensor, tag, stat_id, half_range, override_att, verbose)
def quantize_instant(self, tensor, id, tag="", stat_id=None, half_range=False, override_att=None, verbose=False):
return self.op_manager.quantize_instant(tensor, id, tag, stat_id, half_range, override_att, verbose)

def set_8bit_list(self, ignore_ids):
self.op_manager.set_8bit_list(ignore_ids)
Expand All @@ -356,15 +356,19 @@ def quantize_model(self, model):
for n, m in model.named_modules():
weight_q = None
if isinstance(m, torch.nn.Conv2d):
if m.weight.shape[1] == 3:
# In case of inceptionV3 leave first and second layer at 8 bit
if isinstance(model, torchvision.models.Inception3) and \
(n == 'Conv2d_1a_3x3.conv' or n == 'Conv2d_2a_3x3.conv'):
weight_q = QMI().quantize_instant(m.weight, n + '.weight', "weight", override_att=('num_bits', 8), verbose=True)
elif m.weight.shape[1] == 3:
# first layer leave in 8 bit
weight_q = QMI().quantize_instant(m.weight, "weight", override_att=('num_bits', 8), verbose=True)
weight_q = QMI().quantize_instant(m.weight, n + '.weight', "weight", override_att=('num_bits', 8), verbose=True)
else:
weight_q = QMI().quantize_instant(m.weight, "weight", verbose=True)
weight_q = QMI().quantize_instant(m.weight, n + '.weight', "weight", verbose=True)

elif isinstance(m, torch.nn.Linear):
tag_weight = 'weight_classifier' if m.weight.shape[0] == 1000 else 'weight'
weight_q = QMI().quantize_instant(m.weight, tag_weight, verbose=True)
weight_q = QMI().quantize_instant(m.weight, n + '.weight', tag_weight, verbose=True)

if weight_q is not None:
if self.vcorr_weight or self.bcorr_weight:
Expand Down Expand Up @@ -408,6 +412,7 @@ def __fill_quantizers__(self, qtype, qparams, arch=None, qweight='int8'):
classifier_quantizer.pcq_a = False
classifier_quantizer.sm = StatisticManager
classifier_quantizer.stats_kind = 'max'
classifier_quantizer.measure_entropy = False
self.quantizers['activation_classifier'] = classifier_quantizer

if qweight == 'f32':
Expand All @@ -427,6 +432,7 @@ def __fill_quantizers__(self, qtype, qparams, arch=None, qweight='int8'):
weights_quantizer.kld = False
weights_quantizer.bit_alloc = False
weights_quantizer.stats_kind = 'max'
weights_quantizer.measure_entropy = False
self.quantizers['weight_classifier'] = weights_quantizer

bias_quantizer, _ = self.__load_quantizer__('int8', qparams)
Expand Down Expand Up @@ -466,6 +472,7 @@ def __fill_quantizers__(self, qtype, qparams, arch=None, qweight='int8'):
pooling_quantizer.clipping = 'no'
pooling_quantizer.kld = False
pooling_quantizer.bit_alloc = False
pooling_quantizer.measure_entropy = False
self.quantizers['activation_pooling'] = pooling_quantizer

def __init__(self, args, qparams):
Expand Down Expand Up @@ -539,27 +546,17 @@ def quantize_tensor(self, tensor, fprop=True, bprop=True):
fprop = self.activation_quantizer if fprop else None
return attacher.pytorch_attach(tensor, fprop, None)

def quantize_instant(self, tensor, tag="", stat_id=None, half_range=False, override_att=None, verbose=False):
def quantize_instant(self, tensor, id, tag="", stat_id=None, half_range=False, override_att=None, verbose=False):
# ignore quantization of first and last layer
ignore_cond = False
if stat_id is not None:
ignore_cond = np.array([l == stat_id for l in self.ignore_ids]).any()

# if self.fp32_clip:
# if ignore_cond:
# return self.activations_clipper(tensor, tag, stat_id) if self.rho_act is not None else tensor
# elif (tag == 'activation' or tag == 'activation_classifier' and tensor.shape[1] == 1000) or (tag == 'weight_classifier' and tensor.shape[0] == 1000):
# return tensor # Last linear layer. No clipping here
# elif tag == 'weight':
# return self.weights_clipper(tensor, tag) if self.rho_weight is not None else tensor
# else: # bias, pooling etc..
# return tensor
# else:
qtag = 'ignored' if ignore_cond else tag
q = self.get_quantizer(qtag)
q.half_range = half_range

if verbose:
print("Quantize {0:21} | Id - {1:18} | {2:} | {3:}".format(tag, str(stat_id), str(q), str(tensor.device)))

return q(tensor, tag, stat_id, override_att)
return q(tensor, id, tag, stat_id, override_att)
1 change: 1 addition & 0 deletions pytorch_quantizer/quantization/qtypes/dummy_quantizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

class DummyQuantizer:
def __call__(self, tensor, tag="", stat_id=None, override_att=None):
return tensor
Expand Down
Loading

0 comments on commit c4a930d

Please sign in to comment.