[Quant] Add FX support in quantization examples #5797

Closed
wants to merge 8 commits into from
docs/source/models.rst (12 changes: 6 additions & 6 deletions)
@@ -486,13 +486,13 @@ Model Acc@1 Acc@5
 ================================ ============= =============
 MobileNet V2 71.658 90.150
 MobileNet V3 Large 73.004 90.858
-ShuffleNet V2 x1.0 68.360 87.582
-ShuffleNet V2 x0.5 57.972 79.780
-ResNet 18 69.494 88.882
-ResNet 50 75.920 92.814
-ResNext 101 32x8d 78.986 94.480
+ShuffleNet V2 x1.0 67.886 87.332
+ShuffleNet V2 x0.5 57.784 79.458
+ResNet 18 69.458 88.902
+ResNet 50 75.712 92.782
+ResNext 101 32x8d 78.982 94.422
 Inception V3 77.176 93.354
-GoogleNet 69.826 89.404
+GoogleNet 69.598 89.398
 ================================ ============= =============


references/classification/train_quantization.py (59 changes: 50 additions & 9 deletions)
@@ -2,16 +2,23 @@
 import datetime
 import os
 import time
+from enum import Enum
 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data
 
 
+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,17 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = (
+        QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    )
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
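
As background for the backend check above: PyTorch can only run quantized kernels for engines compiled into the current build, so the script validates args.backend before using it. A minimal sketch of inspecting and selecting an engine; the 'fbgemm' choice below is an assumption for an x86 build, ARM builds typically use 'qnnpack'.

# Sketch only, not part of this diff: inspect and select the quantized backend engine.
import torch

print(torch.backends.quantized.supported_engines)  # e.g. ['qnnpack', 'none', 'fbgemm'], depending on the build

backend = "fbgemm"  # assumed x86 build; use "qnnpack" on ARM
if backend not in torch.backends.quantized.supported_engines:
    raise RuntimeError("Quantized backend not supported: " + backend)
torch.backends.quantized.engine = backend
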
@@ -46,13 +64,20 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    if use_fx_graph_mode_quantization:
+        model = torchvision.models.__dict__[args.model](weights=args.weights)
+    else:
+        model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
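
To make the new QAT branch above easier to follow in isolation, here is a minimal, self-contained sketch of the FX graph mode flow on a toy model (not part of this diff). It mirrors the qconfig-dict API used in this PR; later PyTorch releases moved prepare_qat_fx to a QConfigMapping plus an example_inputs argument, so the exact signature is version-dependent.

# Illustrative FX graph mode QAT sketch on a toy model, not the torchvision classifier.
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
from torch import nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
model.train()

# Trace the model with FX and insert fake-quant modules and observers.
qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict("fbgemm")
prepared = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)

# ... the usual QAT training loop would run on `prepared` here ...

# After training, convert the prepared model into an int8 model for inference.
prepared.eval()
quantized = torch.ao.quantization.quantize_fx.convert_fx(prepared)
out = quantized(torch.randn(1, 3, 32, 32))
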
@@ -84,13 +109,20 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
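
The post-training branch above follows the standard prepare, calibrate, convert sequence; only the prepare and convert calls differ between the eager and FX workflows. A small hedged sketch of the FX variant on a toy model, with random tensors standing in for the calibration loader:

# Illustrative FX graph mode post-training static quantization sketch (not part of this diff).
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
model.eval()

qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
prepared = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)

# Calibrate: run representative data through the prepared model so observers record activation ranges.
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(4, 3, 32, 32))

quantized = torch.ao.quantization.quantize_fx.convert_fx(prepared)
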
@@ -125,7 +157,10 @@ def main(args):
         quantized_eval_model = copy.deepcopy(model_without_ddp)
         quantized_eval_model.eval()
         quantized_eval_model.to(torch.device("cpu"))
-        torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+        else:
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
         print("Evaluate Quantized model")
         evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
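
The block above keeps QAT running on the live model while evaluating a converted copy: the prepared model is deep-copied, moved to CPU (where the quantized inference kernels run), and only the copy is converted. A rough sketch of that pattern, with eval_fn as a hypothetical stand-in for the reference evaluate() helper:

# Sketch: evaluate an intermediate quantized snapshot without disturbing QAT training.
import copy

import torch
import torch.ao.quantization.quantize_fx


def eval_quantized_snapshot(prepared_model, eval_fn):
    snapshot = copy.deepcopy(prepared_model)  # leave the training copy untouched
    snapshot.eval()
    snapshot.to(torch.device("cpu"))  # quantized inference kernels target CPU backends
    quantized = torch.ao.quantization.quantize_fx.convert_fx(snapshot)
    return eval_fn(quantized)
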
@@ -233,6 +268,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
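
With this flag in place, a hypothetical invocation selecting the FX workflow might look like the command below; the model name, data path, and remaining flags are placeholders based on the existing reference script, and only --quantization-workflow-type is new in this PR.

# Hypothetical usage; paths and model name are placeholders.
python train_quantization.py \
    --model mobilenet_v2 \
    --backend fbgemm \
    --post-training-quantize \
    --quantization-workflow-type fx_graph_mode_quantization \
    --data-path /path/to/imagenet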