
[Quant] Add FX support in quantization examples #5797

Closed
wants to merge 8 commits
Update on "[Quant] Add FX support in quantization examples"
Summary: Previously, the quantization examples used only eager
mode quantization. This commit adds support for FX graph mode
quantization as well.
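
For orientation, here is a minimal sketch (not code from this PR) of the two post-training quantization workflows the new `--quantization-workflow-type` flag selects between. It uses the public `torch.ao.quantization` APIs as of the PyTorch releases current when this PR was written; later releases also require an `example_inputs` argument to `prepare_fx`.

```
import torch
from torch.ao.quantization import get_default_qconfig, prepare, convert
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

def ptq_eager(model, data_loader):
    # Eager mode: the model must be a hand-written quantizable variant
    # (QuantStub/DeQuantStub inserted, fusion hook provided), as in
    # torchvision.models.quantization.
    model.eval()
    model.qconfig = get_default_qconfig("fbgemm")
    model.fuse_model()  # hook exposed by torchvision's quantizable models
    prepare(model, inplace=True)
    with torch.inference_mode():
        for images, _ in data_loader:  # calibration pass fills the observers
            model(images)
    convert(model, inplace=True)
    return model

def ptq_fx(model, data_loader):
    # FX graph mode: works on the plain float model; tracing inserts
    # observers and fuses modules automatically.
    model.eval()
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    model = prepare_fx(model, qconfig_dict)
    with torch.inference_mode():
        for images, _ in data_loader:
            model(images)
    return convert_fx(model)
```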

Test Plan:

```
# ==================== PTQ ====================
# MODEL is one of googlenet, inception_v3, resnet18, resnet50, resnext101_32x8d,
# shufflenet_v2_x0_5, shufflenet_v2_x1_0

# eager
python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization"

# fx
python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization"

# ==================== QAT ====================
# mobilenet_v2 eager
python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="eager_mode_quantization"

# mobilenet_v2 fx
python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="fx_graph_mode_quantization"

# mobilenet_v3_large eager
python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="eager_mode_quantization"

# mobilenet_v3_large fx
python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="fx_graph_mode_quantization"
```

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

[ghstack-poisoned]
andrewor14 committed Apr 14, 2022
commit a84092abfb223933bfd6ddac77885296d0a1d176
11 changes: 6 additions & 5 deletions references/classification/train_quantization.py
```
@@ -36,7 +36,9 @@ def main(args):
             "Unknown workflow type '%s', please choose from: %s"
             % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
         )
-    use_fx_graph_mode_quantization = quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    use_fx_graph_mode_quantization = (
+        QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    )

     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
```
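
The first hunk fixes the workflow-type check. Assuming `quantization_workflow_type` holds the upper-cased member name, as the validation against `QuantizationWorkflowType.__members__` above suggests, a bare string never compares equal to an enum member, so the old expression was always False; indexing the Enum class by member name performs the conversion. A self-contained illustration (the enum body here is a hypothetical stand-in; only the lookup pattern is from the diff):

```
from enum import Enum, auto

class QuantizationWorkflowType(Enum):  # hypothetical stand-in for the PR's enum
    EAGER_MODE_QUANTIZATION = auto()
    FX_GRAPH_MODE_QUANTIZATION = auto()

name = "FX_GRAPH_MODE_QUANTIZATION"

# A raw string never equals an enum member, so this is always False:
print(name == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION)  # False

# Indexing the Enum class by member name converts string -> member:
print(QuantizationWorkflowType[name] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION)  # True
```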
```
@@ -61,12 +63,11 @@ def main(args):
     )

     print("Creating model", args.model)
+    # when training quantized models, we always start from a pre-trained fp32 reference model
     if use_fx_graph_mode_quantization:
-        model_namespace = torchvision.models
+        model = torchvision.models.__dict__[args.model](weights=args.weights)
     else:
-        model_namespace = torchvision.models.quantization
-    # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+        model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)

     if not (args.test_only or args.post_training_quantize):
```
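
The second hunk reflects the main practical difference between the two workflows: eager mode needs the hand-written quantizable model definitions under `torchvision.models.quantization` (with stubs and fusion hooks baked in), while FX graph mode starts from the ordinary float models under `torchvision.models`. A small illustration using `resnet18`, one of the models in the test plan:

```
import torchvision

# Eager mode path: quantizable variant with QuantStub/DeQuantStub and a
# fuse_model() hook built in; quantize=True would instead return an
# already quantized model.
eager_model = torchvision.models.quantization.resnet18(quantize=False)

# FX graph mode path: the plain float model suffices, because tracing
# inserts the quantization machinery automatically.
fx_model = torchvision.models.resnet18()
```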