diff --git a/docs/source/models.rst b/docs/source/models.rst
index f84d9c7fd1a..8669916ab35 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -486,13 +486,13 @@ Model                            Acc@1         Acc@5
 ================================ ============= =============
 MobileNet V2                     71.658        90.150
 MobileNet V3 Large               73.004        90.858
-ShuffleNet V2 x1.0               68.360        87.582
-ShuffleNet V2 x0.5               57.972        79.780
-ResNet 18                        69.494        88.882
-ResNet 50                        75.920        92.814
-ResNext 101 32x8d                78.986        94.480
-Inception V3                     77.176        93.354
-GoogleNet                        69.826        89.404
+ShuffleNet V2 x1.0               68.240        87.524
+ShuffleNet V2 x0.5               57.810        79.724
+ResNet 18                        69.500        88.958
+ResNet 50                        75.802        92.764
+ResNext 101 32x8d                79.020        94.468
+Inception V3                     77.206        93.576
+GoogleNet                        69.702        89.388
 ================================ ============= =============
diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py
index c0e5af1dcfc..573461668d0 100644
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -5,10 +5,12 @@
 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
+from torchvision.models.quantization.utils import QuantizationWorkflowType
 from train import train_one_epoch, evaluate, load_data
 
 
@@ -22,6 +24,15 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    all_quantization_workflow_types = [t.value for t in QuantizationWorkflowType]
+    if args.quantization_workflow_type not in all_quantization_workflow_types:
+        raise RuntimeError(
+            "Unknown quantization workflow type '%s', must be one of: %s"
+            % (args.quantization_workflow_type, all_quantization_workflow_types)
+        )
+    quantization_workflow_type = QuantizationWorkflowType(args.quantization_workflow_type)
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -46,13 +57,21 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = torchvision.models.quantization.__dict__[args.model](
+        weights=args.weights,
+        quantize=args.test_only,
+        quantization_workflow_type=quantization_workflow_type,
+    )
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +103,20 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +151,10 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+            if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+                quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+            else:
+                torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
             print("Evaluate Quantized model")
             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +262,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode' (default) or 'fx_graph_mode'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py
index 1794c834eea..97d0c8a8b86 100644
--- a/torchvision/models/quantization/googlenet.py
+++ b/torchvision/models/quantization/googlenet.py
@@ -12,7 +12,7 @@
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
 from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights
-from .utils import _fuse_modules, _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
 
 
 __all__ = [
@@ -170,11 +170,16 @@ def googlenet(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
+
+    if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        model = GoogLeNet(**kwargs)
+    else:
+        model = QuantizableGoogLeNet(**kwargs)
+        _replace_relu(model)
 
-    model = QuantizableGoogLeNet(**kwargs)
-    _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        model = quantize_model(model, backend, quantization_workflow_type)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py
index ff5c9a37365..c8eeac4c8ff 100644
--- a/torchvision/models/quantization/inception.py
+++ b/torchvision/models/quantization/inception.py
@@ -7,13 +7,13 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torchvision.models import inception as inception_module
-from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights
+from torchvision.models.inception import Inception3, InceptionOutputs, Inception_V3_Weights
 
 from ...transforms._presets import ImageClassification, InterpolationMode
 from .._api import WeightsEnum, Weights
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
-from .utils import _fuse_modules, _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
 
 
 __all__ = [
@@ -239,11 +239,16 @@ def inception_v3(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
+
+    if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        model = Inception3(**kwargs)
+    else:
+        model = QuantizableInception3(**kwargs)
+        _replace_relu(model)
 
-    model = QuantizableInception3(**kwargs)
-    _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        model = quantize_model(model, backend, quantization_workflow_type)
 
     if weights is not None:
         if quantize and not original_aux_logits:
diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py
index d9554e0ba9f..422ab6c3c43 100644
--- a/torchvision/models/quantization/mobilenetv2.py
+++ b/torchvision/models/quantization/mobilenetv2.py
@@ -11,7 +11,7 @@
 from .._api import WeightsEnum, Weights
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
-from .utils import _fuse_modules, _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
 
 
 __all__ = [
@@ -125,11 +125,16 @@ def mobilenet_v2(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "qnnpack")
+    quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
+
+    if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        model = MobileNetV2(**kwargs)
+    else:
+        model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
+        _replace_relu(model)
 
-    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
-    _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        model = quantize_model(model, backend, quantization_workflow_type)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index 88907ec210a..203f1c3ad97 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -2,6 +2,7 @@
 from typing import Any, List, Optional, Union
 
 import torch
+import torch.ao.quantization.quantize_fx
 from torch import nn, Tensor
 from torch.ao.quantization import QuantStub, DeQuantStub
 
@@ -17,7 +18,7 @@
     _mobilenet_v3_conf,
     MobileNet_V3_Large_Weights,
 )
-from .utils import _fuse_modules, _replace_relu
+from .utils import _fuse_modules, _replace_relu, QuantizationWorkflowType
 
 
 __all__ = [
@@ -135,20 +136,32 @@ def _mobilenet_v3_model(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "qnnpack")
+    quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
 
-    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
-    _replace_relu(model)
+    if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
+    else:
+        model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
+        _replace_relu(model)
 
     if quantize:
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        # TODO: This shouldn't be QAT?
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
 
     if quantize:
-        torch.ao.quantization.convert(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         model.eval()
 
     return model
diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py
index a781f320000..6e42fda9b12 100644
--- a/torchvision/models/quantization/resnet.py
+++ b/torchvision/models/quantization/resnet.py
@@ -17,7 +17,7 @@
 from .._api import WeightsEnum, Weights
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
-from .utils import _fuse_modules, _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
 
 
 __all__ = [
@@ -134,11 +134,16 @@ def _resnet(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
+
+    if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        model = ResNet(block, layers, **kwargs)
+    else:
+        model = QuantizableResNet(block, layers, **kwargs)
+        _replace_relu(model)
 
-    model = QuantizableResNet(block, layers, **kwargs)
-    _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        model = quantize_model(model, backend, quantization_workflow_type)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py
index 1f4f1890e07..809b7104b43 100644
--- a/torchvision/models/quantization/shufflenetv2.py
+++ b/torchvision/models/quantization/shufflenetv2.py
@@ -10,8 +10,8 @@
 from .._api import WeightsEnum, Weights
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
-from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
-from .utils import _fuse_modules, _replace_relu, quantize_model
+from ..shufflenetv2 import ShuffleNetV2, ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
+from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
 
 
 __all__ = [
@@ -40,7 +40,7 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
-class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
+class QuantizableShuffleNetV2(ShuffleNetV2):
     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)  # type: ignore[misc]
@@ -89,11 +89,16 @@ def _shufflenetv2(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
+
+    if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        model = ShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
+    else:
+        model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
+        _replace_relu(model)
 
-    model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
-    _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        model = quantize_model(model, backend, quantization_workflow_type)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/utils.py b/torchvision/models/quantization/utils.py
index a21e2af8e01..8f1356dd504 100644
--- a/torchvision/models/quantization/utils.py
+++ b/torchvision/models/quantization/utils.py
@@ -1,9 +1,16 @@
+from enum import Enum
 from typing import Any, List, Optional, Union
 
 import torch
+import torch.ao.quantization.quantize_fx
 from torch import nn
 
 
+class QuantizationWorkflowType(Enum):
+    EAGER_MODE = "eager_mode"
+    FX_GRAPH_MODE = "fx_graph_mode"
+
+
 def _replace_relu(module: nn.Module) -> None:
     reassign = {}
     for name, mod in module.named_children():
@@ -18,28 +25,37 @@ def _replace_relu(module: nn.Module) -> None:
         module._modules[key] = value
 
 
-def quantize_model(model: nn.Module, backend: str) -> None:
+def quantize_model(model: nn.Module, backend: str, quantization_workflow_type: QuantizationWorkflowType) -> nn.Module:
     _dummy_input_data = torch.rand(1, 3, 299, 299)
     if backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported ")
     torch.backends.quantized.engine = backend
     model.eval()
-    # Make sure that weight qconfig matches that of the serialized models
-    if backend == "fbgemm":
-        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.ao.quantization.default_observer,
-            weight=torch.ao.quantization.default_per_channel_weight_observer,
-        )
-    elif backend == "qnnpack":
-        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer
-        )
-
-    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
-    model.fuse_model()  # type: ignore[operator]
-    torch.ao.quantization.prepare(model, inplace=True)
-    model(_dummy_input_data)
-    torch.ao.quantization.convert(model, inplace=True)
+    if quantization_workflow_type == QuantizationWorkflowType.EAGER_MODE:
+        # Make sure that weight qconfig matches that of the serialized models
+        if backend == "fbgemm":
+            model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+                activation=torch.ao.quantization.default_observer,
+                weight=torch.ao.quantization.default_per_channel_weight_observer,
+            )
+        elif backend == "qnnpack":
+            model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+                activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer
+            )
+
+        # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+        model.fuse_model()  # type: ignore[operator]
+        torch.ao.quantization.prepare(model, inplace=True)
+        model(_dummy_input_data)
+        torch.ao.quantization.convert(model, inplace=True)
+    elif quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+        qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(backend)
+        model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        model(_dummy_input_data)
+        model = torch.ao.quantization.quantize_fx.convert_fx(model)
+    else:
+        raise ValueError("Unknown quantization workflow type '%s'" % quantization_workflow_type)
+
+    return model
 
 
 def _fuse_modules(
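
For context, below is a minimal usage sketch of the API introduced by this diff, not part of the patch itself. QuantizationWorkflowType and the quantization_workflow_type keyword come from the changes above; the chosen model, weights value, and command-line arguments are illustrative assumptions, and everything else is standard torchvision/PyTorch API.

# Usage sketch, assuming the changes in this PR are applied.
import torch
from torchvision.models import quantization as qmodels
from torchvision.models.quantization.utils import QuantizationWorkflowType

# Build a post-training-quantized ResNet-50 via the FX graph mode path:
# with quantize=True, quantize_model() runs prepare_fx, a dummy calibration
# forward pass, and convert_fx instead of the eager fuse/prepare/convert flow.
model = qmodels.resnet50(
    weights=None,  # no pretrained weights here, so outputs are meaningless
    quantize=True,
    quantization_workflow_type=QuantizationWorkflowType.FX_GRAPH_MODE,
)
model.eval()
with torch.inference_mode():
    print(model(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])

# Hypothetical invocation of the reference script with the new flag
# (other required arguments such as the dataset path are omitted):
#   python train_quantization.py --model resnet50 --backend fbgemm \
#       --post-training-quantize --quantization-workflow-type fx_graph_mode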