diff --git a/docs/source/models.rst b/docs/source/models.rst index f84d9c7fd1a..d2704b5fadb 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -486,13 +486,13 @@ Model Acc@1 Acc@5 ================================ ============= ============= MobileNet V2 71.658 90.150 MobileNet V3 Large 73.004 90.858 -ShuffleNet V2 x1.0 68.360 87.582 -ShuffleNet V2 x0.5 57.972 79.780 -ResNet 18 69.494 88.882 -ResNet 50 75.920 92.814 -ResNext 101 32x8d 78.986 94.480 +ShuffleNet V2 x1.0 67.886 87.332 +ShuffleNet V2 x0.5 57.784 79.458 +ResNet 18 69.458 88.902 +ResNet 50 75.712 92.782 +ResNext 101 32x8d 78.982 94.422 Inception V3 77.176 93.354 -GoogleNet 69.826 89.404 +GoogleNet 69.598 89.398 ================================ ============= ============= diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index c0e5af1dcfc..e2d6b8ec99c 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -2,9 +2,11 @@ import datetime import os import time +from enum import Enum import torch import torch.ao.quantization +import torch.ao.quantization.quantize_fx import torch.utils.data import torchvision import utils @@ -12,6 +14,11 @@ from train import train_one_epoch, evaluate, load_data +class QuantizationWorkflowType(Enum): + EAGER_MODE_QUANTIZATION = 1 + FX_GRAPH_MODE_QUANTIZATION = 2 + + def main(args): if args.output_dir: utils.mkdir(args.output_dir) @@ -22,6 +29,17 @@ def main(args): if args.post_training_quantize and args.distributed: raise RuntimeError("Post training quantization example should not be performed on distributed mode") + # Validate quantization workflow type + quantization_workflow_type = args.quantization_workflow_type.upper() + if quantization_workflow_type not in QuantizationWorkflowType.__members__: + raise RuntimeError( + "Unknown workflow type '%s', please choose from: %s" + % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__]))) + ) + use_fx_graph_mode_quantization = ( + QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION + ) + # Set backend engine to ensure that quantized model runs on the correct kernels if args.backend not in torch.backends.quantized.supported_engines: raise RuntimeError("Quantized backend not supported: " + str(args.backend)) @@ -46,13 +64,20 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model - model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) + if use_fx_graph_mode_quantization: + model = torchvision.models.__dict__[args.model](weights=args.weights) + else: + model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model.to(device) if not (args.test_only or args.post_training_quantize): - model.fuse_model(is_qat=True) - model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend) - torch.ao.quantization.prepare_qat(model, inplace=True) + if use_fx_graph_mode_quantization: + qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend) + model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict) + else: + model.fuse_model(is_qat=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend) + torch.ao.quantization.prepare_qat(model, inplace=True) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -84,13 +109,20 @@ def main(args): ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True ) model.eval() - model.fuse_model(is_qat=False) - model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend) - torch.ao.quantization.prepare(model, inplace=True) + if use_fx_graph_mode_quantization: + qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend) + model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict) + else: + model.fuse_model(is_qat=False) + model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend) + torch.ao.quantization.prepare(model, inplace=True) # Calibrate first print("Calibrating") evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1) - torch.ao.quantization.convert(model, inplace=True) + if use_fx_graph_mode_quantization: + model = torch.ao.quantization.quantize_fx.convert_fx(model) + else: + torch.ao.quantization.convert(model, inplace=True) if args.output_dir: print("Saving quantized model") if utils.is_main_process(): @@ -125,7 +157,10 @@ def main(args): quantized_eval_model = copy.deepcopy(model_without_ddp) quantized_eval_model.eval() quantized_eval_model.to(torch.device("cpu")) - torch.ao.quantization.convert(quantized_eval_model, inplace=True) + if use_fx_graph_mode_quantization: + quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model) + else: + torch.ao.quantization.convert(quantized_eval_model, inplace=True) print("Evaluate Quantized model") evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu")) @@ -233,6 +268,12 @@ def get_args_parser(add_help=True): help="Post training quantize the model", action="store_true", ) + parser.add_argument( + "--quantization-workflow-type", + default="eager_mode_quantization", + type=str, + help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'", + ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")