From c4a930da13cd57e346ca275090f0d62ec440d659 Mon Sep 17 00:00:00 2001 From: submission2019 Date: Wed, 31 Jul 2019 14:37:21 +0300 Subject: [PATCH] update code with VLC --- inference/__init__.py | 1 - inference/inference_sim.py | 49 +-- pytorch_quantizer/quantization/__init__.py | 1 - .../inference_quantization_manager.py | 47 ++- .../quantization/qtypes/dummy_quantizer.py | 1 + .../quantization/qtypes/int_quantizer.py | 291 ++++++++++++++---- .../quantization/quantization_manager.py | 1 - utils/__init__.py | 1 - utils/absorb_bn.py | 1 + utils/attacher.py | 1 + utils/dataset.py | 1 - utils/dump_manager.py | 1 - utils/log.py | 1 - utils/mark_relu.py | 1 - utils/meters.py | 27 +- utils/misc.py | 1 + utils/model_naming.py | 1 + utils/monitor.py | 1 + utils/optim.py | 1 + utils/preprocess.py | 1 + 20 files changed, 311 insertions(+), 119 deletions(-) diff --git a/inference/__init__.py b/inference/__init__.py index 8b13789..e69de29 100644 --- a/inference/__init__.py +++ b/inference/__init__.py @@ -1 +0,0 @@ - diff --git a/inference/inference_sim.py b/inference/inference_sim.py index 57feae3..13369d0 100644 --- a/inference/inference_sim.py +++ b/inference/inference_sim.py @@ -32,7 +32,8 @@ # import pretrainedmodels # import pretrainedmodels.utils as mutils from pathlib import Path -import mlflow + +from utils.mllog import MLlogger torch.backends.cudnn.deterministic = True @@ -41,8 +42,6 @@ home = str(Path.home()) IMAGENET_FOR_INFERENCE = '/home/cvds_lab/datasets/ILSVRC2012/' -mlflow.set_tracking_uri(os.path.join(home, 'mlruns_mxt')) - model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) @@ -100,11 +99,14 @@ parser.add_argument('--per_channel_quant_act', '-pcq_a', action='store_true', help='Per channel quantization of activations', default=False) parser.add_argument('--bit_alloc_act', '-baa', action='store_true', help='Optimal bit allocation for each channel of activations', default=False) parser.add_argument('--bit_alloc_weight', '-baw', action='store_true', help='Optimal bit allocation for each channel of weights', default=False) -parser.add_argument('--bit_alloc_rmode', '-bam', help='One of [round, ceil]', default='ceil') +parser.add_argument('--bit_alloc_rmode', '-bam', help='One of [round, ceil]', default='round') parser.add_argument('--bit_alloc_prior', '-bap', help='One of [gaus, laplace]', default='gaus') +parser.add_argument('--bit_alloc_target_act', '-bata', type=float, help='Target value for bit allocation quota of activations', default=None) +parser.add_argument('--bit_alloc_target_weight', '-batw', type=float, help='Target value for bit allocation quota of weights', default=None) parser.add_argument('--bias_corr_act', '-bca', action='store_true', help='Bias correction for activations', default=False) parser.add_argument('--bias_corr_weight', '-bcw', action='store_true', help='Bias correction for weights', default=False) parser.add_argument('--var_corr_weight', '-vcw', action='store_true', help='Variance correction for weights', default=False) +parser.add_argument('--measure_entropy', '-me', action='store_true', help='Measure entropy of activations', default=False) parser.add_argument('--mlf_experiment', '-mlexp', help='Name of experiment', default=None) args = parser.parse_args() @@ -121,8 +123,12 @@ elif args.arch == 'inception_v3': max_mse_order_id = ['conv5_activation', 'conv12_activation', 'conv1_activation', 'conv7_activation', 'conv4_activation', 'conv2_activation', 'conv14_activation', 'conv19_activation', 'conv10_activation', 'conv92_activation', 'conv21_activation', 'conv22_activation', 'conv9_activation', 'conv77_activation', 'conv16_activation', 'conv47_activation', 'conv48_activation', 'conv17_activation', 'conv58_activation', 'conv8_activation', 'conv55_activation', 'conv56_activation', 'conv40_activation', 'conv63_activation', 'conv15_activation', 'conv62_activation', 'conv84_activation', 'conv54_activation', 'conv57_activation', 'conv52_activation', 'conv65_activation', 'conv91_activation', 'conv76_activation', 'conv34_activation', 'conv51_activation', 'conv85_activation', 'conv53_activation', 'conv83_activation', 'conv35_activation', 'conv50_activation', 'conv46_activation', 'conv82_activation', 'conv61_activation', 'conv30_activation', 'conv37_activation', 'conv67_activation', 'conv75_activation', 'conv64_activation', 'conv29_activation', 'conv66_activation', 'conv44_activation', 'conv33_activation', 'conv43_activation', 'conv38_activation', 'conv45_activation', 'conv42_activation', 'conv23_activation', 'conv36_activation', 'conv60_activation', 'conv32_activation', 'conv41_activation', 'conv79_activation', 'conv6_activation', 'conv13_activation', 'conv78_activation', 'conv20_activation', 'conv73_activation', 'conv74_activation', 'conv80_activation', 'conv31_activation', 'conv27_activation', 'conv81_activation', 'conv88_activation', 'conv68_activation', 'conv28_activation', 'conv26_activation', 'conv89_activation', 'conv72_activation', 'conv93_activation', 'conv90_activation', 'conv94_activation', 'conv3_activation', 'conv24_activation', 'conv87_activation', 'conv18_activation', 'conv69_activation', 'conv59_activation', 'conv25_activation', 'conv49_activation', 'linear1_activation', 'conv39_activation', 'conv86_activation', 'conv11_activation', 'conv95_activation'] +torch.manual_seed(12345) + + class InferenceModel: - def __init__(self): + def __init__(self, ml_logger=None): + self.ml_logger = ml_logger global args, best_prec1 if args.seed is not None: @@ -260,12 +266,12 @@ def run(self): print(elog) else: val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion) - if mlflow.active_run() is not None: - mlflow.log_metric('top1', val_prec1) - mlflow.log_metric('top5', val_prec5) - mlflow.log_metric('loss', val_loss) - return val_loss, val_prec1, val_prec5 + if self.ml_logger is not None and self.ml_logger.mlflow.active_run() is not None: + self.ml_logger.mlflow.log_metric('top1', val_prec1) + self.ml_logger.mlflow.log_metric('top5', val_prec5) + self.ml_logger.mlflow.log_metric('loss', val_loss) + return val_loss, val_prec1, val_prec5 def validate(val_loader, model, criterion): @@ -287,6 +293,8 @@ def validate(val_loader, model, criterion): if (args.stats_mode == 'collect' and i*args.batch_size >= args.cal_set_size and (args.kld_threshold or args.aciq_cal)) or \ (args.subset is not None and i*args.batch_size >= args.subset): break + if args.measure_entropy and i*args.batch_size >= args.subset: + break # Uncomment to enable dump # QM().disable() # if i > 0: @@ -333,7 +341,7 @@ def validate(val_loader, model, criterion): return losses.avg, top1.avg, top5.avg -def get_params(): +def get_params(logger=None): qparams = { 'int': { 'clipping': args.clipping, @@ -346,9 +354,13 @@ def get_params(): 'bit_alloc_weight': args.bit_alloc_weight, 'bit_alloc_rmode': args.bit_alloc_rmode, 'bit_alloc_prior': args.bit_alloc_prior, + 'bit_alloc_target_act': args.bit_alloc_target_act, + 'bit_alloc_target_weight': args.bit_alloc_target_weight, 'bcorr_act': args.bias_corr_act, 'bcorr_weight': args.bias_corr_weight, - 'vcorr_weight': args.var_corr_weight + 'vcorr_weight': args.var_corr_weight, + 'logger': logger, + 'measure_entropy': args.measure_entropy }, 'qmanager':{ 'rho_act': args.rho_act, @@ -357,15 +369,14 @@ def get_params(): } # TODO: add params for bfloat return qparams + if __name__ == '__main__': if args.stats_mode != 'collect': - mlflow.set_experiment(args.arch if args.mlf_experiment is None else args.mlf_experiment) - with mlflow.start_run(run_name="{}_W{}_A{}".format(args.arch, args.qweight, args.qtype)): - params = vars(args) - for p in params: - mlflow.log_param(p, params[p]) - with QM(args, get_params()): - im = InferenceModel() + experiment = args.arch if args.mlf_experiment is None else args.mlf_experiment + with MLlogger(os.path.join(home, 'mlruns_mxt'), experiment, args, + name_args=[args.arch, "W{}A{}".format(args.qweight, args.qtype)]) as ml_logger: + with QM(args, get_params(ml_logger)): + im = InferenceModel(ml_logger) im.run() else: with QM(args, get_params()): diff --git a/pytorch_quantizer/quantization/__init__.py b/pytorch_quantizer/quantization/__init__.py index 8b13789..e69de29 100644 --- a/pytorch_quantizer/quantization/__init__.py +++ b/pytorch_quantizer/quantization/__init__.py @@ -1 +0,0 @@ - diff --git a/pytorch_quantizer/quantization/inference/inference_quantization_manager.py b/pytorch_quantizer/quantization/inference/inference_quantization_manager.py index 94ef3f4..4efc4aa 100644 --- a/pytorch_quantizer/quantization/inference/inference_quantization_manager.py +++ b/pytorch_quantizer/quantization/inference/inference_quantization_manager.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torchvision from pytorch_quantizer.quantization import qtypes from utils.misc import Singleton from utils import attacher @@ -17,7 +18,6 @@ from pytorch_quantizer.clipping.clipping_manager import StatisticalClipper, RatioClipper from pytorch_quantizer.quantization.qtypes.dummy_quantizer import DummyQuantizer - VERBOSE = True class StatsMode(Enum): @@ -66,10 +66,10 @@ def forward(self, input): QMI().stats_manager.save_tensor_stats(out, 'activation_pooling', out_id) elif QMI().stats_mode is StatsMode.use_stats: # Quantize using statistics - out = QMI().quantize_instant(out, "activation_pooling", stat_id=out_id, verbose=QMI().verbose) + out = QMI().quantize_instant(out, out_id, "activation_pooling", stat_id=out_id, verbose=QMI().verbose) else: # No stats, quantize using actual values - out = QMI().quantize_instant(out, "activation_pooling", verbose=QMI().verbose) + out = QMI().quantize_instant(out, out_id, "activation_pooling", verbose=QMI().verbose) return out @@ -174,7 +174,7 @@ def forward(self, input): QMI().stats_manager.save_tensor_stats(out, self.internal_name, activation_id) elif QMI().stats_mode is StatsMode.use_stats: # Quantize using statistics - out_q = QMI().quantize_instant(out, tag_act, stat_id=activation_id, + out_q = QMI().quantize_instant(out, activation_id, tag_act, stat_id=activation_id, half_range=hasattr(self, 'before_relu'), verbose=QMI().verbose) # print("%s: %d" % (activation_id, out.shape[2]*out.shape[3])) if QMI().bcorr_act: @@ -204,7 +204,7 @@ def forward(self, input): else: # No stats, quantize using actual values - out = QMI().quantize_instant(out, tag_act, half_range=hasattr(self, 'before_relu'), verbose=QMI().verbose) + out = QMI().quantize_instant(out, activation_id, tag_act, half_range=hasattr(self, 'before_relu'), verbose=QMI().verbose) if QMI().measure_stats.enabled: QMI().measure_stats.save_measure(out, activation_id) @@ -236,13 +236,13 @@ def forward(self, input): if QMI().stats_mode is StatsMode.collect_stats: QMI().stats_manager.save_tensor_stats(out, tag_act, activation_id, force_global_min_max=('classifier' in tag_act)) elif QMI().stats_mode is StatsMode.use_stats: - out_q = QMI().quantize_instant(out, tag_act, stat_id=activation_id, half_range=half_range, + out_q = QMI().quantize_instant(out, activation_id, tag_act, stat_id=activation_id, half_range=half_range, verbose=QMI().verbose) out = out_q else: - out = QMI().quantize_instant(out, tag_act, half_range=half_range, verbose=QMI().verbose) + out = QMI().quantize_instant(out, activation_id, tag_act, half_range=half_range, verbose=QMI().verbose) if QMI().measure_stats.enabled: QMI().measure_stats.save_measure(out, activation_id) @@ -339,8 +339,8 @@ def createTruncationManager(self, args, qparams): return op_manager - def quantize_instant(self, tensor, tag="", stat_id=None, half_range=False, override_att=None, verbose=False): - return self.op_manager.quantize_instant(tensor, tag, stat_id, half_range, override_att, verbose) + def quantize_instant(self, tensor, id, tag="", stat_id=None, half_range=False, override_att=None, verbose=False): + return self.op_manager.quantize_instant(tensor, id, tag, stat_id, half_range, override_att, verbose) def set_8bit_list(self, ignore_ids): self.op_manager.set_8bit_list(ignore_ids) @@ -356,15 +356,19 @@ def quantize_model(self, model): for n, m in model.named_modules(): weight_q = None if isinstance(m, torch.nn.Conv2d): - if m.weight.shape[1] == 3: + # In case of inceptionV3 leave first and second layer at 8 bit + if isinstance(model, torchvision.models.Inception3) and \ + (n == 'Conv2d_1a_3x3.conv' or n == 'Conv2d_2a_3x3.conv'): + weight_q = QMI().quantize_instant(m.weight, n + '.weight', "weight", override_att=('num_bits', 8), verbose=True) + elif m.weight.shape[1] == 3: # first layer leave in 8 bit - weight_q = QMI().quantize_instant(m.weight, "weight", override_att=('num_bits', 8), verbose=True) + weight_q = QMI().quantize_instant(m.weight, n + '.weight', "weight", override_att=('num_bits', 8), verbose=True) else: - weight_q = QMI().quantize_instant(m.weight, "weight", verbose=True) + weight_q = QMI().quantize_instant(m.weight, n + '.weight', "weight", verbose=True) elif isinstance(m, torch.nn.Linear): tag_weight = 'weight_classifier' if m.weight.shape[0] == 1000 else 'weight' - weight_q = QMI().quantize_instant(m.weight, tag_weight, verbose=True) + weight_q = QMI().quantize_instant(m.weight, n + '.weight', tag_weight, verbose=True) if weight_q is not None: if self.vcorr_weight or self.bcorr_weight: @@ -408,6 +412,7 @@ def __fill_quantizers__(self, qtype, qparams, arch=None, qweight='int8'): classifier_quantizer.pcq_a = False classifier_quantizer.sm = StatisticManager classifier_quantizer.stats_kind = 'max' + classifier_quantizer.measure_entropy = False self.quantizers['activation_classifier'] = classifier_quantizer if qweight == 'f32': @@ -427,6 +432,7 @@ def __fill_quantizers__(self, qtype, qparams, arch=None, qweight='int8'): weights_quantizer.kld = False weights_quantizer.bit_alloc = False weights_quantizer.stats_kind = 'max' + weights_quantizer.measure_entropy = False self.quantizers['weight_classifier'] = weights_quantizer bias_quantizer, _ = self.__load_quantizer__('int8', qparams) @@ -466,6 +472,7 @@ def __fill_quantizers__(self, qtype, qparams, arch=None, qweight='int8'): pooling_quantizer.clipping = 'no' pooling_quantizer.kld = False pooling_quantizer.bit_alloc = False + pooling_quantizer.measure_entropy = False self.quantizers['activation_pooling'] = pooling_quantizer def __init__(self, args, qparams): @@ -539,22 +546,12 @@ def quantize_tensor(self, tensor, fprop=True, bprop=True): fprop = self.activation_quantizer if fprop else None return attacher.pytorch_attach(tensor, fprop, None) - def quantize_instant(self, tensor, tag="", stat_id=None, half_range=False, override_att=None, verbose=False): + def quantize_instant(self, tensor, id, tag="", stat_id=None, half_range=False, override_att=None, verbose=False): # ignore quantization of first and last layer ignore_cond = False if stat_id is not None: ignore_cond = np.array([l == stat_id for l in self.ignore_ids]).any() - # if self.fp32_clip: - # if ignore_cond: - # return self.activations_clipper(tensor, tag, stat_id) if self.rho_act is not None else tensor - # elif (tag == 'activation' or tag == 'activation_classifier' and tensor.shape[1] == 1000) or (tag == 'weight_classifier' and tensor.shape[0] == 1000): - # return tensor # Last linear layer. No clipping here - # elif tag == 'weight': - # return self.weights_clipper(tensor, tag) if self.rho_weight is not None else tensor - # else: # bias, pooling etc.. - # return tensor - # else: qtag = 'ignored' if ignore_cond else tag q = self.get_quantizer(qtag) q.half_range = half_range @@ -562,4 +559,4 @@ def quantize_instant(self, tensor, tag="", stat_id=None, half_range=False, overr if verbose: print("Quantize {0:21} | Id - {1:18} | {2:} | {3:}".format(tag, str(stat_id), str(q), str(tensor.device))) - return q(tensor, tag, stat_id, override_att) + return q(tensor, id, tag, stat_id, override_att) diff --git a/pytorch_quantizer/quantization/qtypes/dummy_quantizer.py b/pytorch_quantizer/quantization/qtypes/dummy_quantizer.py index 8a40a57..2c62162 100644 --- a/pytorch_quantizer/quantization/qtypes/dummy_quantizer.py +++ b/pytorch_quantizer/quantization/qtypes/dummy_quantizer.py @@ -1,3 +1,4 @@ + class DummyQuantizer: def __call__(self, tensor, tag="", stat_id=None, override_att=None): return tensor diff --git a/pytorch_quantizer/quantization/qtypes/int_quantizer.py b/pytorch_quantizer/quantization/qtypes/int_quantizer.py index f9417ba..c94a535 100644 --- a/pytorch_quantizer/quantization/qtypes/int_quantizer.py +++ b/pytorch_quantizer/quantization/qtypes/int_quantizer.py @@ -6,7 +6,8 @@ from utils.monitor import Monitor from pytorch_quantizer.quantization.inference.statistic_manager import StatisticManager from pytorch_quantizer.quantization.inference.statistic_manager_perchannel import StatisticManagerPerChannel - +from utils.entropy import shannon_entropy, most_requent_value_compression +import scipy.optimize as opt # Alpha coeficients for for gaussian clipping # [1.71063519 2.15159277 2.55913646 2.93620062 3.28691474 3.6151146 3.92403714] @@ -29,7 +30,29 @@ def to_numpy(tensor): else: return tensor + +def laplace_prior_mse(b, alpha, omega): + return 2 * (b ** 2) * np.exp(-alpha / b) + (alpha ** 2 / (3 * omega**2)) + +def half_laplace_prior_mse(b, alpha, omega): + return (b ** 2) * np.exp(-alpha / b) + (alpha ** 2 / (24 * omega**2)) + +# Numpy code to find optimal alpha for real omega +resolution = 20 +omega_table = np.concatenate([np.linspace(0.01, 0.1, resolution, endpoint=False), + np.linspace(0.1, 1, resolution, endpoint=False), + np.linspace(1, 10, resolution, endpoint=False), + np.linspace(10, 100, resolution, endpoint=False), + np.linspace(100, 1000, resolution, endpoint=False)]) + +alpha_table = np.array([opt.minimize_scalar(lambda x: laplace_prior_mse(b=1, alpha=x, omega=w)).x for w in omega_table]) +alpha_table = np.concatenate([[0], alpha_table]) + +omega_table = np.concatenate([[0], omega_table]) + count = 0 + + class IntQuantizer(Function): def __init__(self, size, params): self.num_bits = size @@ -49,19 +72,23 @@ def __init__(self, size, params): self.vcorr_weight = params['vcorr_weight'] self.bit_alloc_round = params['bit_alloc_rmode'] == 'round' self.bit_alloc_prior = params['bit_alloc_prior'] + self.bit_alloc_target_act = params['bit_alloc_target_act'] if params['bit_alloc_target_act'] is not None else self.num_bits + self.bit_alloc_target_weight = params['bit_alloc_target_weight'] if params['bit_alloc_target_weight'] is not None else self.num_bits + self.measure_entropy = params['measure_entropy'] + self.logger = params['logger'] - self.alpha_gaus = {1 : 1.24, 2: 1.71, 3: 2.15, 4: 2.55, 5: 2.93, 6: 3.28, 7: 3.61, 8: 3.92} - self.alpha_gaus_positive = {1 : 1.71, 2: 2.15, 3: 2.55, 4: 2.93, 5: 3.28, 6: 3.61, 7: 3.92, 8: 4.2} + self.alpha_gaus = {1: 1.24, 2: 1.71, 3: 2.15, 4: 2.55, 5: 2.93, 6: 3.28, 7: 3.61, 8: 3.92} + self.alpha_gaus_positive = {1: 1.71, 2: 2.15, 3: 2.55, 4: 2.93, 5: 3.28, 6: 3.61, 7: 3.92, 8: 4.2} - self.alpha_laplace = {0 : 1.05, 1 : 1.86, 2: 2.83, 3: 3.89, 4: 5.03, 5: 6.2, 6: 7.41, 7: 8.64, 8: 9.89} - self.alpha_laplace_positive = {0 : 1.86, 1 : 2.83, 2: 3.89, 3: 5.02, 4: 6.2, 5: 7.41, 6: 8.64, 7: 9.89, 8: 11.16} + self.alpha_laplace = {0: 1.05, 1: 1.86, 2: 2.83, 3: 3.89, 4: 5.03, 5: 6.2, 6: 7.41, 7: 8.64, 8: 9.89} + self.alpha_laplace_positive = {0: 1.86, 1: 2.83, 2: 3.89, 3: 5.02, 4: 6.2, 5: 7.41, 6: 8.64, 7: 9.89, 8: 11.16} self.gaussian_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5) self.sm = StatisticManagerPerChannel if params['pcq_act'] else StatisticManager self.force_positive = False self.half_range = False - def __call__(self, tensor, tag="", stat_id=None, override_att=None): + def __call__(self, tensor, id, tag="", stat_id=None, override_att=None): if override_att is not None: orig_att = getattr(self, override_att[0]) setattr(self, override_att[0], override_att[1]) @@ -69,13 +96,17 @@ def __call__(self, tensor, tag="", stat_id=None, override_att=None): res = self.gemmlowpKldQuantize(tensor, tag, stat_id=stat_id) elif self.clipping != 'no': # print("clipping %s: %d" % (tag, self.num_bits)) - res = self.gemmlowpClippingQuantize(tensor, tag, stat_id=stat_id, clip_type=self.clipping) + # TODO: select between mid-tread and gemmlowp by cmd flag + res = self.gemmlowpClippingQuantize(tensor, id, tag, stat_id=stat_id, clip_type=self.clipping) + # res = self.mid_tread_quantize_activation(tensor, id) elif self.pcq_w: # print("pcq_w %s: %d" % (tag, self.num_bits)) - res = self.gemmlowpQuantizeWeightsPerChannel(tensor) + res = self.gemmlowpQuantizeWeightsPerChannel(tensor, id) + # res = self.mid_tread_quantize_weights_per_channel(tensor, id) elif self.pcq_a and len(tensor.shape) > 3 and (tensor.shape[2] > 1 or tensor.shape[3] > 1): # print("pcq_a %s: %d" % (tag, self.num_bits)) - res = self.gemmlowpQuantizeActivationPerChannel(tensor, tag, stat_id=stat_id) + res = self.gemmlowpQuantizeActivationPerChannel(tensor, id, tag, stat_id=stat_id) + # res = self.mid_tread_quantize_activation_per_channel(tensor, id) else: # print("no clipping %s: %d" % (tag, self.num_bits)) res = self.gemmlowpMinMaxQuantize(tensor, tag, stat_id=stat_id) @@ -85,8 +116,107 @@ def __call__(self, tensor, tag="", stat_id=None, override_att=None): return res def __repr__(self): - return 'IntQuantizer - [bits: {}, clipping: {}, bit_alloc_act: {}, bit_alloc_weight: {}, pcq_w: {}, pcq_a: {}, bcorr_act: {}, bcorr_weight: {}, vcorr_weight: {}, kind: {}]'\ - .format(self.num_bits, self.clipping, self.bit_alloc_act, self.bit_alloc_weight, self.pcq_w, self.pcq_a, self.bcorr_act, self.bcorr_weight, self.vcorr_weight, self.stats_kind) + return 'IntQuantizer - [bits: {}, clipping: {}, bit_alloc_act: {}, bit_alloc_weight: {}, bit_alloc_round: {}, pcq_w: {}, pcq_a: {}, bcorr_act: {}, bcorr_weight: {}, vcorr_weight: {}, kind: {}]'\ + .format(self.num_bits, self.clipping, self.bit_alloc_act, self.bit_alloc_weight, self.bit_alloc_round, self.pcq_w, self.pcq_a, self.bcorr_act, self.bcorr_weight, self.vcorr_weight, self.stats_kind) + + @staticmethod + def get_omega(sigma, target_bins): + B = len(sigma) * target_bins + + # Calculate bit allocation + p = sigma ** (2./3) + omega = (B * p) / p.sum() + return omega + + @staticmethod + def get_alpha_mult(omega, sym=True): + omega = omega.cpu().numpy() + if not sym: + omega *= 2 + i = omega_table.searchsorted(omega) + inc = (alpha_table[i] - alpha_table[i - 1]) / (omega_table[i] - omega_table[i - 1]) + alpha = alpha_table[i] - inc * (omega_table[i] - omega) + return alpha + + def mid_tread_quantize_weights_per_channel(self, tensor, id): + # Assume weights with dimensions [OFM,IFM,K1,K2] + t = tensor.view(tensor.shape[0], -1) + + tq, entropy = self.mid_tread_quantization(t, id, self.bit_alloc_target_weight, clip=False, sym=True) + if entropy is not None and self.logger is not None: + self.logger.log_metric(id + '.entropy', entropy.item(), step='auto', meterId='avg.entropy.weight', + weight=tq.numel()) + + return tq.view(tensor.shape) + + def mid_tread_quantize_activation(self, tensor, id): + # Assume [N, C, H, W] or [N, M] + if self.pcq_a and len(tensor.shape) > 3 and (tensor.shape[2] > 1 or tensor.shape[3] > 1): + # scale per channel + out = self.mid_tread_quantize_activation_per_channel(tensor, id) + else: + # scale per tensor + symmetric = not (self.force_positive or self.half_range) + out, entropy = self.mid_tread_quantization(tensor.view(1, -1), id, self.bit_alloc_target_act, clip=True, sym=symmetric) + + return out.view(tensor.shape) + + def mid_tread_quantize_activation_per_channel(self, tensor, id): + N, C, H, W = tensor.shape # N x C x H x W + t = tensor.detach().transpose(0, 1).contiguous() # C x N x H x W + t = t.view(t.shape[0], -1) + + symmetric = not (self.force_positive or self.half_range) + tq, entropy = self.mid_tread_quantization(t, id, self.bit_alloc_target_act, clip=True, sym=symmetric) + + if entropy is not None and self.logger is not None: + self.logger.log_metric(id + '.entropy', entropy.item(), step='auto', meterId='avg.entropy.act', + weight=tensor.numel()) + + output = tq.view(C, N, H, W).transpose(0, 1).contiguous() # N x C x H x W + return output.view(tensor.shape) + + def mid_tread_quantization(self, tensor, id, target, clip=False, sym=True): + std = tensor.std(-1) + omega = self.get_omega(std, target_bins=(2**target)).round() + + if clip: + alpha_mult = tensor.new_tensor(self.get_alpha_mult(omega, sym=sym)) + mu = tensor.mean(dim=-1) + b = torch.mean(torch.abs(tensor - mu.unsqueeze(-1)), dim=-1) + + rng = (2 * alpha_mult * b) if sym else (torch.max(mu, mu.new_tensor([0.])) + alpha_mult * b) + else: + rng = (tensor.max(-1)[0] - tensor.min(-1)[0]) if sym else tensor.max(-1)[0] + + Delta = torch.where(omega > 0, rng / omega, + tensor.new_tensor([np.finfo(np.float32).max])) + + # quantize + out = tensor / Delta.unsqueeze(-1) + out.round_() + + # clamp + if clip: + # Centralize quantization range around mean and make it non-negative for asymetric case + mu_q = mu / Delta if sym else torch.max(mu, mu.new_tensor([0.])) / Delta + c_max = mu_q + (omega / 2 if sym else omega) + c_min = ((mu_q - omega / 2) if sym else tensor.new_tensor([0])) + + # In practice all the parameters Delta, omega, c_max, c_min can be pre-calculated based on statistics + out = torch.min(out, c_max.unsqueeze(-1)) + out = torch.max(out, c_min.unsqueeze(-1)) + + if self.measure_entropy: + entropy = shannon_entropy(out, handle_negative=True) + # workaround for out of memory issue + torch.cuda.empty_cache() + else: + entropy = None + + # dequantize + out.mul_(Delta.unsqueeze(-1)) + return out, entropy def get_alpha_laplace(self, tensor, stat_id=None, kind='mean', per_channel=False): if stat_id is not None: @@ -107,13 +237,14 @@ def get_alpha_laplace(self, tensor, stat_id=None, kind='mean', per_channel=False std = self.__act_stats_perchannel__(tensor, [prior], avg_over_batch=False)[prior] else: std = self.__act_stats__(tensor, [prior], avg_over_batch=False)[prior] - bit_alloc = self.get_bits_alloc(std, self.num_bits, self.bit_alloc_round) + + bit_alloc = self.get_bits_alloc_fixed_target(std, self.bit_alloc_target_act, self.bit_alloc_round) aciq_factor = np.array([(self.alpha_laplace_positive[nbit.item()] if (self.force_positive or self.half_range) else self.alpha_laplace[nbit.item()]) for nbit in bit_alloc]) aciq_factor = to_cuda(aciq_factor, tensor.device) else: aciq_factor = (self.alpha_laplace_positive[self.num_bits] if (self.force_positive or self.half_range) else self.alpha_laplace[self.num_bits]) - return b * aciq_factor + return to_cuda(b, tensor.device) * aciq_factor def get_alpha_gaus(self, tensor, tag, stat_id=None, per_channel=False): if stat_id is not None: @@ -126,6 +257,17 @@ def get_alpha_gaus(self, tensor, tag, stat_id=None, per_channel=False): return std * (self.alpha_gaus_positive[self.num_bits] if (self.force_positive or self.half_range) else self.alpha_gaus[self.num_bits]) + def get_alpha_pstd(self, tensor, p, tag, stat_id=None, per_channel=False): + if stat_id is not None: + std = self.sm().get_tensor_stat(stat_id, 'std', 'mean') + else: + if per_channel: + std = self.__act_stats_perchannel__(tensor, ['std'], avg_over_batch=False)['std'] + else: + std = self.__act_stats__(tensor, ['std'], avg_over_batch=False)['std'] + + return p * std + def get_alpha_exp(self, tensor, stat_id=None, per_channel=False): if stat_id is not None: mean_abs = self.sm().get_tensor_stat(stat_id, 'mean') @@ -156,6 +298,9 @@ def get_alpha(self, tensor, tag="", stat_id=None, clip_type='laplace', per_chann alpha = self.get_alpha_laplace(tensor, stat_id, per_channel=per_channel) # laplace clipping elif clip_type == 'gaus': alpha = self.get_alpha_gaus(tensor, tag, stat_id, per_channel=per_channel) # gaussian clipping + elif 'std' in clip_type: + p = float(clip_type.replace('std', '')) + alpha = self.get_alpha_pstd(tensor, p, tag, stat_id, per_channel=per_channel) # 2std clipping elif clip_type == 'mix': mse_laplace = self.sm().get_tensor_stat(stat_id, 'mse_laplace', 'mean') mse_gaus = self.sm().get_tensor_stat(stat_id, 'mse_gaus', 'mean') @@ -173,8 +318,7 @@ def get_alpha(self, tensor, tag="", stat_id=None, clip_type='laplace', per_chann return alpha - - def gemmlowpClippingQuantize(self, tensor, tag="", stat_id=None, clip_type='laplace'): + def gemmlowpClippingQuantize(self, tensor, id, tag="", stat_id=None, clip_type='laplace'): if stat_id is not None: min_value = self.sm().get_tensor_stat(stat_id, 'min', 'mean') max_value = self.sm().get_tensor_stat(stat_id, 'max', 'mean') @@ -199,7 +343,7 @@ def gemmlowpClippingQuantize(self, tensor, tag="", stat_id=None, clip_type='lapl min_value = to_cuda(min_value, tensor.device) range = to_cuda(range, tensor.device) max_ = min_value + range - res = self.gemmlowpQuantizeActivationPerChannel(tensor.contiguous(), tag, stat_id, min_=min_value, max_=max_) + res = self.gemmlowpQuantizeActivationPerChannel(tensor.contiguous(), id, tag, stat_id, min_=min_value, max_=max_) else: alpha = self.get_alpha(tensor, tag, stat_id, clip_type, per_channel=False) max_value = float(max_value); min_value = float(min_value); mean = float(mean); alpha = float(alpha) @@ -230,17 +374,33 @@ def gemmlowpMinMaxQuantize(self, tensor, tag="", stat_id=None): @staticmethod def get_bits_alloc(alpha, num_bits, round=False): - # Quota assuming 4 bit target B = len(alpha) * 2 ** num_bits # Calculate bit allocation - p = alpha ** (2 / 3) + p = alpha ** (2. / 3) bin_alloc = (B * p) / p.sum() - bin_alloc[bin_alloc < 1] = 2 bit_alloc = torch.round(torch.log2(bin_alloc)) if round else torch.ceil(torch.log2(bin_alloc)) + bit_alloc[bit_alloc < 0] = 0 + bit_alloc[bit_alloc > 8] = 8 + return bit_alloc + + @staticmethod + def get_bits_alloc_fixed_target(alpha, num_bits, round=False): + eps = 0.01 + goal_bits = num_bits + target_bits = goal_bits + delta = 1. + iter = 0 + max_iter = 10 + while abs(2 * delta) > eps and iter < max_iter: + iter += 1 + bit_alloc = IntQuantizer.get_bits_alloc(alpha, num_bits=target_bits, round=round) + delta = (goal_bits - bit_alloc.mean()) / 2 + target_bits += delta.item() + return bit_alloc - def gemmlowpQuantizeActivationPerChannel(self, tensor, tag="", stat_id=None, min_=None, max_=None): + def gemmlowpQuantizeActivationPerChannel(self, tensor, id, tag="", stat_id=None, min_=None, max_=None): if min_ is None: if self.force_positive or self.half_range: min_ = 0 # np.zeros(min_.shape) @@ -269,15 +429,22 @@ def gemmlowpQuantizeActivationPerChannel(self, tensor, tag="", stat_id=None, min else: alpha = self.__act_stats_perchannel__(tensor, [prior], avg_over_batch=False)[prior] - bit_alloc = self.get_bits_alloc(alpha, self.num_bits, self.bit_alloc_round) + bit_alloc = self.get_bits_alloc_fixed_target(alpha, self.bit_alloc_target_act, self.bit_alloc_round) else: bit_alloc = None - output = self.__gemmlowpQuantize1__(t, max_ - min_, min_, bit_alloc=bit_alloc) + if self.measure_entropy: + output, entropy = self.__gemmlowpQuantize1__(t, max_ - min_, min_, bit_alloc=bit_alloc, measure_entropy=True) + if self.logger is not None: + self.logger.log_metric(id + '.entropy', entropy.item(), step='auto', meterId='avg.entropy.act', weight=output.numel()) + else: + output = self.__gemmlowpQuantize1__(t, max_ - min_, min_, bit_alloc=bit_alloc, + measure_entropy=self.measure_entropy) + output = output.view(C, N, H, W).transpose(0, 1).contiguous() # N x C x H x W return output.view(tensor.shape) - def gemmlowpQuantizeWeightsPerChannel(self, tensor, min_=None, max_=None): + def gemmlowpQuantizeWeightsPerChannel(self, tensor, id, min_=None, max_=None): # Assume weights with dimensions [OFM,IFM,K1,K2] t = tensor.view(tensor.shape[0], -1) @@ -289,11 +456,16 @@ def gemmlowpQuantizeWeightsPerChannel(self, tensor, min_=None, max_=None): if self.bit_alloc_weight and self.num_bits <= 4: alpha = t.std(-1) - bit_alloc = self.get_bits_alloc(alpha, self.num_bits, self.bit_alloc_round) + bit_alloc = self.get_bits_alloc_fixed_target(alpha, self.bit_alloc_target_weight, self.bit_alloc_round) else: bit_alloc = None - output = self.__gemmlowpQuantize1__(t, max_ - min_, min_, bit_alloc=bit_alloc) + if self.measure_entropy: + output, entropy = self.__gemmlowpQuantize1__(t, max_ - min_, min_, bit_alloc=bit_alloc, measure_entropy=True) + if self.logger is not None: + self.logger.log_metric(id + '.entropy', entropy.item(), step='auto', meterId='avg.entropy.weight', weight=output.numel()) + else: + output = self.__gemmlowpQuantize1__(t, max_ - min_, min_, bit_alloc=bit_alloc) return output.view(tensor.shape) @@ -307,25 +479,24 @@ def gemmlowpKldQuantize(self, tensor, tag="", stat_id=None): return self.__gemmlowpQuantize__(tensor, range, offset) - def symlowpQuantize(self, tensor): maxabs = torch.max(tensor.detach().abs()) return self.__symlowpQuantize__(tensor, maxabs) - @staticmethod - def mse_laplace(b, alpha, num_bits): - return 2 * (b ** 2) * np.exp(-alpha / b) + ((alpha ** 2) / (3 * 2 ** (2 * num_bits))) - - @staticmethod - def mse_exponential(mean_abs, alpha, num_bits): - return 2 * (mean_abs ** 2) * np.exp(-alpha / mean_abs) + ((alpha ** 2) / (3 * 2 ** (2 * num_bits))) - - @staticmethod - def mse_gaus(sigma, alpha, num_bits): - clipping_err = (sigma ** 2 + (alpha ** 2)) * (1 - math.erf(alpha / (sigma * np.sqrt(2.0)))) - \ - np.sqrt(2.0 / np.pi) * alpha * sigma * (np.e ** ((-1) * (0.5 * (alpha ** 2)) / sigma ** 2)) - quant_err = (alpha ** 2) / (3 * (2 ** (2 * num_bits))) - return clipping_err + quant_err + # @staticmethod + # def mse_laplace(b, alpha, num_bits): + # return 2 * (b ** 2) * np.exp(-alpha / b) + ((alpha ** 2) / (3 * 2 ** (2 * num_bits))) + # + # @staticmethod + # def mse_exponential(mean_abs, alpha, num_bits): + # return 2 * (mean_abs ** 2) * np.exp(-alpha / mean_abs) + ((alpha ** 2) / (3 * 2 ** (2 * num_bits))) + # + # @staticmethod + # def mse_gaus(sigma, alpha, num_bits): + # clipping_err = (sigma ** 2 + (alpha ** 2)) * (1 - math.erf(alpha / (sigma * np.sqrt(2.0)))) - \ + # np.sqrt(2.0 / np.pi) * alpha * sigma * (np.e ** ((-1) * (0.5 * (alpha ** 2)) / sigma ** 2)) + # quant_err = (alpha ** 2) / (3 * (2 ** (2 * num_bits))) + # return clipping_err + quant_err @staticmethod def __act_stats__(tensor, stats, avg_over_batch=False): @@ -377,36 +548,16 @@ def __act_stats_perchannel__(tensor, stats, avg_over_batch=False): return stats_dict - def __clip_and_mse_mesure(self, tensor, tag, stat_id, clip_type, max_value, min_value, mean, std, b): - if clip_type == 'laplace': - alpha = self.get_alpha_laplace(tensor, stat_id) # laplace clipping - mse_est = IntQuantizer.mse_laplace(b, alpha, self.num_bits) - elif clip_type == 'gaus': - alpha = self.get_alpha_gaus(tensor, tag, stat_id) # gaussian clipping - mse_est = IntQuantizer.mse_gaus(std, alpha, self.num_bits) - elif clip_type == 'exp': - alpha = self.get_alpha_exp(tensor, stat_id) # exponential clipping - mse_est = -1 - else: # no clipping - alpha = (max_value - min_value)/2 - mse_est = -1 - - delta, min_value = self.alpha2DeltaOffset(alpha, max_value, min_value, mean) - res = self.__gemmlowpQuantize__(tensor.contiguous(), delta, min_value) - mse = torch.mean((tensor - res)**2) - del res - return mse, mse_est - - def __gemmlowpQuantize1__(self, tensor, delta, offset, bit_alloc=None): + def __gemmlowpQuantize1__(self, tensor, delta, offset, bit_alloc=None, measure_entropy=False): qmin = 0. if bit_alloc is None: qmax = 2.**self.num_bits - 1. + scale = (delta) / (qmax - qmin) else: qmax = 2.**bit_alloc - 1. - #import pdb; pdb.set_trace() - scale = (delta) / (qmax - qmin) + scale = torch.where(qmax > 0, (delta) / (qmax - qmin), torch.tensor([0.]).to(tensor.device)) - scale = torch.max(scale, torch.tensor([1e-8]).to(scale.device)) + scale = torch.max(scale, torch.tensor([1e-8]).to(tensor.device)) output = tensor.detach() if self.enforce_true_zero: @@ -426,6 +577,10 @@ def __gemmlowpQuantize1__(self, tensor, delta, offset, bit_alloc=None): output = torch.where(output.gt(qmax), qmax, output) output.clamp_(qmin).round_() + if measure_entropy: + entropy = shannon_entropy(output.int()) + # entropy = most_requent_value_compression(output.int()) + if self.enforce_true_zero: output = torch.add(output, -zero_point.unsqueeze(-1)) output = torch.mul(output, scale.unsqueeze(-1)) # dequantize @@ -433,7 +588,13 @@ def __gemmlowpQuantize1__(self, tensor, delta, offset, bit_alloc=None): output = torch.mul(output, scale.unsqueeze(-1)) output = torch.add(output, offset.unsqueeze(-1)) # dequantize - return output.view(tensor.shape) + # workaround for out of memory issue + torch.cuda.empty_cache() + + if measure_entropy: + return output.view(tensor.shape), entropy + else: + return output.view(tensor.shape) def __gemmlowpQuantize__(self, tensor, delta, offset): if self.stochastic: diff --git a/pytorch_quantizer/quantization/quantization_manager.py b/pytorch_quantizer/quantization/quantization_manager.py index de0f5c3..1911672 100644 --- a/pytorch_quantizer/quantization/quantization_manager.py +++ b/pytorch_quantizer/quantization/quantization_manager.py @@ -5,7 +5,6 @@ from utils.monitor import Monitor import abc - INFERENCE_ONLY = False class QuantizationManagerBase(metaclass=Singleton): diff --git a/utils/__init__.py b/utils/__init__.py index 8b13789..e69de29 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1 +0,0 @@ - diff --git a/utils/absorb_bn.py b/utils/absorb_bn.py index 8894cfb..2d7f75e 100644 --- a/utils/absorb_bn.py +++ b/utils/absorb_bn.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + def absorb_bn(module, bn_module): w = module.weight.data if module.bias is None: diff --git a/utils/attacher.py b/utils/attacher.py index 80e0ef8..40f8b58 100644 --- a/utils/attacher.py +++ b/utils/attacher.py @@ -1,6 +1,7 @@ from torch.autograd import Function # f is any callable object + # attacher to forward class attach_to_forward_class(Function): @staticmethod diff --git a/utils/dataset.py b/utils/dataset.py index b56399f..ec3b22b 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -2,7 +2,6 @@ from torch.utils.data import Dataset from numpy.random import choice - class RandomSamplerReplacment(torch.utils.data.sampler.Sampler): """Samples elements randomly, with replacement. Arguments: diff --git a/utils/dump_manager.py b/utils/dump_manager.py index 9e4beec..d559aea 100644 --- a/utils/dump_manager.py +++ b/utils/dump_manager.py @@ -6,7 +6,6 @@ import shutil import uuid - class DumpManager(metaclass=Singleton): def __init__(self, dump_dir=None): if dump_dir is None: diff --git a/utils/log.py b/utils/log.py index 64bdc7d..46bfa29 100644 --- a/utils/log.py +++ b/utils/log.py @@ -6,7 +6,6 @@ from datetime import datetime import json - import pandas as pd from bokeh.io import output_file, save, show from bokeh.plotting import figure diff --git a/utils/mark_relu.py b/utils/mark_relu.py index 45dd964..751b335 100644 --- a/utils/mark_relu.py +++ b/utils/mark_relu.py @@ -1,7 +1,6 @@ from torchvision.models.resnet import Bottleneck, BasicBlock from torch.nn.parallel.data_parallel import DataParallel - def mark_bottlenetck_before_relu(model): for m in model.children(): if isinstance(m, Bottleneck): diff --git a/utils/meters.py b/utils/meters.py index 9eebec8..6fd0872 100644 --- a/utils/meters.py +++ b/utils/meters.py @@ -1,9 +1,28 @@ import torch + +class ProgressMeter(object): + def __init__(self, num_batches, *meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def print(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + class AverageMeter(object): """Computes and stores the average and current value""" - - def __init__(self): + def __init__(self, name='', fmt=':f'): + self.name = name + self.fmt = fmt self.reset() def reset(self): @@ -18,6 +37,10 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + class OnlineMeter(object): """Computes and stores the average and variance/std values of tensor""" diff --git a/utils/misc.py b/utils/misc.py index c319713..330d479 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -75,6 +75,7 @@ def __call__(cls, *args, **kwargs): import re + def sorted_nicely(l): """ Sorts the given iterable in the way that is expected. diff --git a/utils/model_naming.py b/utils/model_naming.py index 1a95b3b..d3a4b1b 100644 --- a/utils/model_naming.py +++ b/utils/model_naming.py @@ -1,5 +1,6 @@ from torch.nn.parallel.data_parallel import DataParallel + def module_type_to_string(m): return (str(type(m)).replace('>', '').replace('\'', '').split('.')[-1]).replace('WithId', '') diff --git a/utils/monitor.py b/utils/monitor.py index 4f06690..fd0142c 100644 --- a/utils/monitor.py +++ b/utils/monitor.py @@ -14,6 +14,7 @@ def __call__(self, *arg, **kwarg): instance.__class__ = _ + class Monitor(metaclass=Singleton): def __init__(self, dump_dir=None): if dump_dir is None: diff --git a/utils/optim.py b/utils/optim.py index cd064dd..28cf7be 100644 --- a/utils/optim.py +++ b/utils/optim.py @@ -9,6 +9,7 @@ def eval_func(f, x): f = eval(f) return f(x) + class OptimRegime(object): """ Reconfigures the optimizer according to setting list. diff --git a/utils/preprocess.py b/utils/preprocess.py index 64e855b..60f4dee 100644 --- a/utils/preprocess.py +++ b/utils/preprocess.py @@ -14,6 +14,7 @@ ]) } + def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): t_list = [ transforms.CenterCrop(input_size),