[Quant] Add FX support in quantization examples
Summary: Previously, the quantization examples used only eager
mode quantization. This commit adds support for FX graph mode
quantization as well.

Test Plan:

```

python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="fx_graph_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="fx_graph_mode_quantization"
```
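
These commands exercise both workflows end to end. Internally, the new `--quantization-workflow-type` flag selects between the two post-training quantization paths sketched below. This is a condensed, illustrative sketch that mirrors the calls added to `train_quantization.py` in this diff; the model, backend, and calibration batch are placeholders, and the two-argument `prepare_fx(model, qconfig_dict)` signature matches the PyTorch version targeted here (later releases also require `example_inputs`).

```python
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as quantize_fx
from torchvision.models import mobilenet_v2
from torchvision.models.quantization import mobilenet_v2 as quantizable_mobilenet_v2

backend = "fbgemm"
calibration_batch = torch.randn(8, 3, 224, 224)  # stand-in for the calibration data loader
workflow = "fx_graph_mode_quantization"  # or "eager_mode_quantization"

if workflow == "fx_graph_mode_quantization":
    # FX graph mode: operates on the plain fp32 model; fusion happens inside the
    # graph transformation, and prepare/convert return new modules.
    model = mobilenet_v2().eval()
    qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(backend)
    model = quantize_fx.prepare_fx(model, qconfig_dict)
    model(calibration_batch)  # calibrate observers
    model = quantize_fx.convert_fx(model)
else:
    # Eager mode: requires the Quantizable* model variant with explicit fusion and
    # qconfig assignment; prepare/convert mutate the model in place.
    model = quantizable_mobilenet_v2(quantize=False).eval()
    model.fuse_model(is_qat=False)
    model.qconfig = torch.ao.quantization.get_default_qconfig(backend)
    torch.ao.quantization.prepare(model, inplace=True)
    model(calibration_batch)  # calibrate observers
    torch.ao.quantization.convert(model, inplace=True)
```

The QAT runs in the test plan follow the same split, using `get_default_qat_qconfig_dict`/`prepare_qat_fx` in FX mode and `fuse_model(is_qat=True)`/`prepare_qat` in eager mode, as shown in the diff below.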

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 6308d7bf03516d6aefea0d985de8e5486d8751ce
Pull Request resolved: #5797
andrewor14 committed Apr 21, 2022
1 parent 3122ea1 commit a68bd2d
Showing 9 changed files with 151 additions and 62 deletions.
12 changes: 6 additions & 6 deletions docs/source/models.rst
@@ -486,13 +486,13 @@ Model Acc@1 Acc@5
================================ ============= =============
MobileNet V2                     71.658        90.150
MobileNet V3 Large               73.004        90.858
-ShuffleNet V2 x1.0              68.360        87.582
-ShuffleNet V2 x0.5              57.972        79.780
-ResNet 18                       69.494        88.882
-ResNet 50                       75.920        92.814
-ResNext 101 32x8d               78.986        94.480
+ShuffleNet V2 x1.0              68.240        87.524
+ShuffleNet V2 x0.5              57.810        79.724
+ResNet 18                       69.500        88.958
+ResNet 50                       75.802        92.764
+ResNext 101 32x8d               79.020        94.468
Inception V3                     77.176        93.354
-GoogleNet                       69.826        89.404
+GoogleNet                       69.702        89.388
================================ ============= =============


53 changes: 44 additions & 9 deletions references/classification/train_quantization.py
@@ -5,10 +5,12 @@

import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
import torch.utils.data
import torchvision
import utils
from torch import nn
from torchvision.models.quantization.utils import QuantizationWorkflowType
from train import train_one_epoch, evaluate, load_data


@@ -22,6 +24,15 @@ def main(args):
if args.post_training_quantize and args.distributed:
raise RuntimeError("Post training quantization example should not be performed on distributed mode")

# Validate quantization workflow type
all_quantization_workflow_types = [t.value for t in QuantizationWorkflowType]
if args.quantization_workflow_type not in all_quantization_workflow_types:
raise RuntimeError(
"Unknown quantization workflow type '%s', must be one of: %s"
% (args.quantization_workflow_type, all_quantization_workflow_types)
)
quantization_workflow_type = QuantizationWorkflowType(args.quantization_workflow_type)

# Set backend engine to ensure that quantized model runs on the correct kernels
if args.backend not in torch.backends.quantized.supported_engines:
raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -46,13 +57,21 @@ def main(args):

print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model = torchvision.models.quantization.__dict__[args.model](
weights=args.weights,
quantize=args.test_only,
quantization_workflow_type=quantization_workflow_type,
)
model.to(device)

if not (args.test_only or args.post_training_quantize):
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
torch.ao.quantization.prepare_qat(model, inplace=True)
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
else:
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
torch.ao.quantization.prepare_qat(model, inplace=True)

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +103,20 @@ def main(args):
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
)
model.eval()
model.fuse_model(is_qat=False)
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
torch.ao.quantization.prepare(model, inplace=True)
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
else:
model.fuse_model(is_qat=False)
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
torch.ao.quantization.prepare(model, inplace=True)
# Calibrate first
print("Calibrating")
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
torch.ao.quantization.convert(model, inplace=True)
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = torch.ao.quantization.quantize_fx.convert_fx(model)
else:
torch.ao.quantization.convert(model, inplace=True)
if args.output_dir:
print("Saving quantized model")
if utils.is_main_process():
@@ -125,7 +151,10 @@ def main(args):
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device("cpu"))
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
else:
torch.ao.quantization.convert(quantized_eval_model, inplace=True)

print("Evaluate Quantized model")
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +262,12 @@ def get_args_parser(add_help=True):
help="Post training quantize the model",
action="store_true",
)
parser.add_argument(
"--quantization-workflow-type",
default="eager_mode",
type=str,
help="The quantization workflow type to use, either 'eager_mode' (default) or 'fx_graph_mode'",
)

# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
13 changes: 9 additions & 4 deletions torchvision/models/quantization/googlenet.py
@@ -12,7 +12,7 @@
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights
from .utils import _fuse_modules, _replace_relu, quantize_model
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType


__all__ = [
@@ -170,11 +170,16 @@ def googlenet(
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)

if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = GoogLeNet(**kwargs)
else:
model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)

model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
model = quantize_model(model, backend, quantization_workflow_type)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
15 changes: 10 additions & 5 deletions torchvision/models/quantization/inception.py
@@ -7,13 +7,13 @@
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights
from torchvision.models.inception import Inception3, InceptionOutputs, Inception_V3_Weights

from ...transforms._presets import ImageClassification, InterpolationMode
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType


__all__ = [
@@ -239,11 +239,16 @@ def inception_v3(
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)

if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = Inception3(**kwargs)
else:
model = QuantizableInception3(**kwargs)
_replace_relu(model)

model = QuantizableInception3(**kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
model = quantize_model(model, backend, quantization_workflow_type)

if weights is not None:
if quantize and not original_aux_logits:
13 changes: 9 additions & 4 deletions torchvision/models/quantization/mobilenetv2.py
@@ -11,7 +11,7 @@
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType


__all__ = [
@@ -125,11 +125,16 @@ def mobilenet_v2(
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)

if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = MobileNetV2(**kwargs)
else:
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)

model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
model = quantize_model(model, backend, quantization_workflow_type)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
27 changes: 20 additions & 7 deletions torchvision/models/quantization/mobilenetv3.py
@@ -2,6 +2,7 @@
from typing import Any, List, Optional, Union

import torch
import torch.ao.quantization.quantize_fx
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

@@ -17,7 +18,7 @@
_mobilenet_v3_conf,
MobileNet_V3_Large_Weights,
)
from .utils import _fuse_modules, _replace_relu
from .utils import _fuse_modules, _replace_relu, QuantizationWorkflowType


__all__ = [
@@ -135,20 +136,32 @@ def _mobilenet_v3_model(
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)

model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
else:
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)

if quantize:
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
torch.ao.quantization.prepare_qat(model, inplace=True)
# TODO: This shouldn't be QAT?
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(backend)
model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
else:
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
torch.ao.quantization.prepare_qat(model, inplace=True)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

if quantize:
torch.ao.quantization.convert(model, inplace=True)
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = torch.ao.quantization.quantize_fx.convert_fx(model)
else:
torch.ao.quantization.convert(model, inplace=True)
model.eval()

return model
13 changes: 9 additions & 4 deletions torchvision/models/quantization/resnet.py
@@ -17,7 +17,7 @@
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType


__all__ = [
@@ -134,11 +134,16 @@ def _resnet(
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)

if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = ResNet(block, layers, **kwargs)
else:
model = QuantizableResNet(block, layers, **kwargs)
_replace_relu(model)

model = QuantizableResNet(block, layers, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
model = quantize_model(model, backend, quantization_workflow_type)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
17 changes: 11 additions & 6 deletions torchvision/models/quantization/shufflenetv2.py
@@ -10,8 +10,8 @@
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
from .utils import _fuse_modules, _replace_relu, quantize_model
from ..shufflenetv2 import ShuffleNetV2, ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType


__all__ = [
@@ -40,7 +40,7 @@ def forward(self, x: Tensor) -> Tensor:
return out


class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
class QuantizableShuffleNetV2(ShuffleNetV2):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) # type: ignore[misc]
@@ -89,11 +89,16 @@ def _shufflenetv2(
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)

if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
model = ShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
else:
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
_replace_relu(model)

model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
model = quantize_model(model, backend, quantization_workflow_type)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
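
Taken together, the model-builder changes above route FX graph mode through the plain fp32 architectures and a shared `quantize_model(model, backend, quantization_workflow_type)` helper (the `QuantizationWorkflowType` enum lives in `torchvision.models.quantization.utils`, whose diff is not rendered above). A hedged usage sketch, assuming the kwarg and enum names exactly as they appear in this diff:

```python
from torchvision.models.quantization import resnet18
from torchvision.models.quantization.utils import QuantizationWorkflowType

# Default workflow: eager mode builds QuantizableResNet and converts it in place.
eager_model = resnet18(quantize=True)

# FX graph mode: builds the plain ResNet and quantizes it via the FX
# prepare/convert path inside quantize_model (per this diff).
fx_model = resnet18(
    quantize=True,
    quantization_workflow_type=QuantizationWorkflowType.FX_GRAPH_MODE,
)
```

Because FX conversion returns a new `GraphModule` rather than mutating the input, the builders now reassign `model = quantize_model(...)` instead of calling the helper purely for its side effects.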