[Quant] Add FX support in quantization examples
Summary: Previously, the quantization examples used only eager
mode quantization. This commit adds support for FX graph mode
quantization as well.
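
For context, here is a minimal sketch of how the two post-training quantization workflows differ, using the same torch.ao.quantization calls this commit wires into the example script. The model, backend, and calibration step below are illustrative placeholders, not part of the change.

```
# Sketch only: contrasts eager mode and FX graph mode post-training quantization.
# mobilenet_v2 and "fbgemm" are illustrative choices; calibration is elided.
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
import torchvision

torch.backends.quantized.engine = "fbgemm"

# Eager mode: uses the hand-written quantizable model from torchvision.models.quantization
eager_model = torchvision.models.quantization.mobilenet_v2(weights="IMAGENET1K_V1", quantize=False)
eager_model.eval()
eager_model.fuse_model(is_qat=False)                      # manual module fusion
eager_model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
torch.ao.quantization.prepare(eager_model, inplace=True)  # insert observers
# ... run calibration batches through eager_model here ...
torch.ao.quantization.convert(eager_model, inplace=True)  # swap in quantized modules

# FX graph mode: works on the plain float model; fusion and qconfig handling are automatic
fx_model = torchvision.models.mobilenet_v2(weights="IMAGENET1K_V1")
fx_model.eval()
qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
# (newer PyTorch releases also expect an example_inputs argument here)
fx_model = torch.ao.quantization.quantize_fx.prepare_fx(fx_model, qconfig_dict)
# ... run calibration batches through fx_model here ...
fx_model = torch.ao.quantization.quantize_fx.convert_fx(fx_model)
```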

Test Plan:

```

python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="fx_graph_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="fx_graph_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="fx_graph_mode_quantization"
```

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 6308d7bf03516d6aefea0d985de8e5486d8751ce
Pull Request resolved: #5797
andrewor14 committed Apr 14, 2022
1 parent 467b841 commit 21c8623
Showing 2 changed files with 56 additions and 15 deletions.
12 changes: 6 additions & 6 deletions docs/source/models.rst
```
@@ -486,13 +486,13 @@ Model Acc@1 Acc@5
 ================================ ============= =============
 MobileNet V2                                 71.658        90.150
 MobileNet V3 Large                           73.004        90.858
-ShuffleNet V2 x1.0                           68.360        87.582
-ShuffleNet V2 x0.5                           57.972        79.780
-ResNet 18                                    69.494        88.882
-ResNet 50                                    75.920        92.814
-ResNext 101 32x8d                            78.986        94.480
+ShuffleNet V2 x1.0                           67.886        87.332
+ShuffleNet V2 x0.5                           57.784        79.458
+ResNet 18                                    69.458        88.902
+ResNet 50                                    75.712        92.782
+ResNext 101 32x8d                            78.982        94.422
 Inception V3                                 77.176        93.354
-GoogleNet                                    69.826        89.404
+GoogleNet                                    69.598        89.398
 ================================ ============= =============
```


59 changes: 50 additions & 9 deletions references/classification/train_quantization.py
```
@@ -2,16 +2,23 @@
 import datetime
 import os
 import time
+from enum import Enum

 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data


+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,17 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")

+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = (
+        QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    )
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -46,13 +64,20 @@ def main(args):

     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    if use_fx_graph_mode_quantization:
+        model = torchvision.models.__dict__[args.model](weights=args.weights)
+    else:
+        model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)

     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +109,20 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +157,10 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+            if use_fx_graph_mode_quantization:
+                quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+            else:
+                torch.ao.quantization.convert(quantized_eval_model, inplace=True)

             print("Evaluate Quantized model")
             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +268,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )

     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
```

