[Quant] Add FX support in quantization examples
Summary: Previously, the quantization examples used only eager
mode quantization. This commit adds support for FX graph mode
quantization as well. TODO: provide accuracy comparison.

Test Plan:

python train_quantization.py \
  --device='cpu' \
  --post-training-quantize \
  --backend='fbgemm' \
  --model='$MODEL'

where $MODEL is one of: googlenet, inception_v3, resnet18, resnet50,
resnext101_32x8d, shufflenet_v2_x0_5, and shufflenet_v2_x1_0
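
To exercise the new FX path, the same command can be run with the flag added by this commit, for example (an illustrative invocation, not part of the original test plan):

python train_quantization.py \
  --device='cpu' \
  --post-training-quantize \
  --backend='fbgemm' \
  --model='resnet18' \
  --quantization-workflow-type='fx_graph_mode_quantization'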

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 50b3f482794f42c742e9e61ff3f4fcbf2b040703
Pull Request resolved: #5797
andrewor14 committed Apr 12, 2022
1 parent 467b841 commit 4afdab5
1 changed file: references/classification/train_quantization.py (51 additions, 9 deletions)
@@ -2,16 +2,23 @@
 import datetime
 import os
 import time
+from enum import Enum

 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data


+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,15 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")

+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION.name
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -45,14 +61,23 @@ def main(args):
     )

     print("Creating model", args.model)
+    if use_fx_graph_mode_quantization:
+        model_namespace = torchvision.models
+    else:
+        model_namespace = torchvision.models.quantization
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)

     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = qconfig
+            torch.ao.quantization.prepare_qat(model, inplace=True)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +109,21 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = qconfig
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +158,10 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+            if use_fx_graph_mode_quantization:
+                quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+            else:
+                torch.ao.quantization.convert(quantized_eval_model, inplace=True)

             print("Evaluate Quantized model")
             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +269,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )

     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
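
For quantization-aware training, the diff applies the same split: prepare_qat_fx before training and convert_fx on a CPU copy for evaluation, versus fuse_model/prepare_qat/convert in eager mode. Below is a minimal QAT sketch of the FX path under the same assumptions as above; train_step is a placeholder, and prepare_qat_fx is called with a qconfig_dict as in this commit.

import copy

import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx


def qat_fx_example(model, train_step, data_loader, epochs, backend="fbgemm"):
    # Insert fake-quant and observer modules by tracing the float model
    qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
    model.train()
    model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, {"": qconfig})
    for epoch in range(epochs):
        train_step(model, data_loader, epoch)  # placeholder: one epoch of training
        # Convert a CPU snapshot to int8 so evaluation does not touch the model being trained
        quantized_eval_model = copy.deepcopy(model).eval().to(torch.device("cpu"))
        quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
        # ... evaluate quantized_eval_model here ...
    return model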
