
Commit 5461630

Commit message: up

1 parent 1904fa8 commit 5461630

4 files changed: +21 -13 lines changed

.lintrunner.toml

Lines changed: 0 additions & 1 deletion
@@ -397,7 +397,6 @@ exclude_patterns = [
     "extension/llm/export/builder.py",
     "extension/llm/export/quantizer_lib.py",
     "exir/tests/test_memory_planning.py",
-    "backends/transforms/duplicate_dynamic_quant_chain.py",
     "exir/backend/test/demos/test_xnnpack_qnnpack.py",
 ]

backends/transforms/duplicate_dynamic_quant_chain.py

Lines changed: 4 additions & 6 deletions
@@ -9,14 +9,12 @@

 import torch

-from torch.ao.quantization.pt2e.utils import (
-    _filter_sym_size_users,
-    _is_valid_annotation,
-)
-
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase, PassResult

+from torchao.quantization.pt2e.quantizer import is_valid_annotation
+from torchao.quantization.pt2e.utils import _filter_sym_size_users
+

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)

@@ -129,7 +127,7 @@ def _maybe_duplicate_dynamic_quantize_chain(
     dq_node_users = list(dq_node.users.copy())
     for user in dq_node_users:
         annotation = user.meta.get("quantization_annotation", None)
-        if not _is_valid_annotation(annotation):
+        if not is_valid_annotation(annotation):
             return
         with gm.graph.inserting_after(dq_node):
             new_node = gm.graph.node_copy(dq_node)
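
For context on the change above: the pass now uses torchao's public is_valid_annotation helper in place of the private _is_valid_annotation from torch.ao. A minimal sketch of how that check can be applied when scanning a graph, assuming only the import path shown in this diff; the count_annotated_nodes helper is hypothetical and not part of the commit:

# Sketch, not part of the commit: shows the annotation check the pass relies on.
import torch
from torchao.quantization.pt2e.quantizer import is_valid_annotation


def count_annotated_nodes(gm: torch.fx.GraphModule) -> int:
    # Count nodes carrying a valid "quantization_annotation" entry in node.meta,
    # the same metadata key the pass inspects above.
    count = 0
    for node in gm.graph.nodes:
        annotation = node.meta.get("quantization_annotation", None)
        if is_valid_annotation(annotation):
            count += 1
    return count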

extension/llm/export/builder.py

Lines changed: 16 additions & 5 deletions
@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
@@ -35,11 +35,15 @@

 from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
 from pytorch_tokenizers import get_tokenizer
-from torch.ao.quantization.quantizer import Quantizer
-from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+from torch.ao.quantization.quantizer import Quantizer as TorchQuantizer
+from torch.ao.quantization.quantizer.composable_quantizer import (
+    ComposableQuantizer as TorchComposableQuantizer,
+)
+
 from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
 from torchao.utils import unwrap_tensor_subclass

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -350,7 +354,9 @@ def calibrate_template(
             print(f"{task}: {res}")
         logging.info("Calibration finish...")

-    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
+    def pt2e_quantize(
+        self, quantizers: Optional[List[Union[Quantizer, TorchQuantizer]]]
+    ) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
         Args:
@@ -367,7 +373,12 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
             if self.verbose:
                 logging.info(f"Applied quantizers: {quantizers}")
-            composed_quantizer = ComposableQuantizer(quantizers)
+
+            if any(isinstance(q, Quantizer) for q in quantizers):
+                composed_quantizer = ComposableQuantizer(quantizers)
+            else:
+                composed_quantizer = TorchComposableQuantizer(quantizers)
+
             assert (
                 self.pre_autograd_graph_module is not None
             ), "Please run export() first"

extension/llm/export/quantizer_lib.py

Lines changed: 1 addition & 1 deletion
@@ -154,7 +154,7 @@ def get_qnn_quantizer(
             QnnQuantizer,
             QuantDtype,
         )
-        from torch.ao.quantization.observer import MinMaxObserver
+        from torchao.quantization.pt2e import MinMaxObserver

     except ImportError:
         raise ImportError(
