Remove pt2_quant flag (#3676)
Summary: Pull Request resolved: #3676

Reviewed By: dulinriley

Differential Revision: D57491621

fbshipit-source-id: 6a63e239839be950948085e392604c0ffc62e01a
mcremon-meta authored and facebook-github-bot committed May 21, 2024
1 parent 07dcf35 commit a707550
Showing 2 changed files with 14 additions and 22 deletions.
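
In short: callers no longer pass a pt2_quant flag. A pt2-quantized model arrives as a torch.fx.GraphModule, so export_program() can infer the quantized case from the model's type. A minimal before/after sketch of the call site (distilled from the diff below, not an exact excerpt):

# Before this commit: the caller had to declare that the model was
# pt2-quantized.
edge_prog_manager, expo_prog = export_to_edge(
    converted_model, example_inputs, pt2_quant=True
)

# After this commit: no flag; export_program() checks
# isinstance(model, torch.fx.GraphModule) internally instead.
edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)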
32 changes: 13 additions & 19 deletions backends/cadence/aot/compiler.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Any, Callable, Tuple
+from typing import Any, Tuple
 
 import torch
 
@@ -16,25 +16,20 @@
 
 
 def export_program(
-    model: Callable,
+    model: torch.nn.Module,
     inputs: Any,
-    pt2_quant: bool = False,
 ) -> ExportedProgram:
-    # we don't support training mode. Make it eval
-    if hasattr(model, "eval"):
-        if pt2_quant:
-            # pyre-fixme[6]: Incompatible parameter type.
-            torch.ao.quantization.move_exported_model_to_eval(model)
-        else:
-            # pyre-fixme[16]: Anonymous callable has no attribute `eval`.
-            model.eval()
-
-    # if it's already an ExportedProgram, just return it
-    if isinstance(model, ExportedProgram):
-        return model
-
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
+    # If the model is already a GraphModule (most likely from quantization), call the
+    # suggested torch.ao.quantization API instead, which only does dropout and batchnorm.
+    if isinstance(model, torch.fx.GraphModule):
+        torch.ao.quantization.move_exported_model_to_eval(model)
+    else:
+        # We don't support training mode. Make it eval
+        if hasattr(model, "eval"):
+            model.eval()
 
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)
@@ -44,13 +39,12 @@ def export_program(
 
 # Export the model and lower it it edge IR.
 def export_to_edge(
-    model: Callable,
+    model: torch.nn.Module,
     inputs: Any,
-    pt2_quant: bool = False,
     dump_graphs: bool = False,
 ) -> Tuple[EdgeProgramManager, ExportedProgram]:
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs, pt2_quant)
+    expo_program = export_program(model, inputs)
 
     if dump_graphs:
         logging.info(f"Exported graph:\n{expo_program.graph_module.graph}")
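
For reference, the new eval-mode dispatch in export_program() can be exercised on its own. A minimal runnable sketch, assuming only what the hunk above shows (the helper name _move_to_eval and SmallModel are hypothetical):

import torch


def _move_to_eval(model: torch.nn.Module) -> None:
    # Mirrors the branch added above: an exported/quantized graph gets the
    # dedicated API, which only rewrites dropout and batchnorm; a plain
    # eager module gets regular eval().
    if isinstance(model, torch.fx.GraphModule):
        torch.ao.quantization.move_exported_model_to_eval(model)
    elif hasattr(model, "eval"):
        model.eval()


class SmallModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear(x))


model = SmallModel()
_move_to_eval(model)  # eager path: falls through to model.eval()
assert not model.training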
4 changes: 1 addition & 3 deletions backends/cadence/aot/export_example.py
@@ -69,9 +69,7 @@ def export_model(
     QuantFusion(patterns)(converted_model)
 
     # Get edge program (note: the name will change to export_to_cadence in future PRs)
-    edge_prog_manager, expo_prog = export_to_edge(
-        converted_model, example_inputs, pt2_quant=True
-    )
+    edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)
 
     # Run a couple required passes for quant/dequant ops
     cadence_prog_manager = edge_prog_manager.transform(
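
With the flag gone, a caller needs nothing beyond the model and its example inputs. A usage sketch, assuming the compiler.py above is importable as executorch.backends.cadence.aot.compiler (ToyAdd and its inputs are hypothetical):

import torch

from executorch.backends.cadence.aot.compiler import export_to_edge


class ToyAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0


# dump_graphs=True triggers the logging.info() call shown in the
# compiler.py hunk above.
edge_prog_manager, expo_prog = export_to_edge(
    ToyAdd(), (torch.randn(4),), dump_graphs=True
)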
