MMP Misc enhancements #3692

Merged · 3 commits · Dec 20, 2024

Changes from 1 commit:
Always use input quantizer at idx=0 for Concat
Signed-off-by: yathindra kota <quic_ykota@quicinc.com>
quic-ykota committed Dec 20, 2024
commit bbba7f6beb55f9a51b1224b183944f61b817e79a
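For context, the behavior this commit relies on can be sketched in isolation. The following is a hypothetical illustration, not part of the diff; it assumes aimet_torch.elementwise_ops.Concat is the float concat op that QuantSim converts to QuantizedConcat, mirroring the aimet_elementwise alias used in the test below:

import torch
from aimet_torch import elementwise_ops  # assumed import path for the float Concat op
from aimet_torch.v2.quantsim import QuantizationSimModel

class TwoWayConcat(torch.nn.Module):
    """Toy model: two tensors feed a single Concat op."""
    def __init__(self):
        super().__init__()
        self.concat = elementwise_ops.Concat()

    def forward(self, x, y):
        return self.concat(x, y)

dummy_input = (torch.randn(1, 4), torch.randn(1, 4))
sim = QuantizationSimModel(TwoWayConcat(), dummy_input)

# Two graph inputs feed the concat, but QuantSim exposes a single input
# quantizer for it, so per-input lookups must always use idx=0.
assert len(sim.model.concat.input_quantizers) == 1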
@@ -41,8 +41,11 @@
import copy
from typing import Dict, List, Tuple, Optional, Union, IO

+import torch.nn
+
from aimet_common.defs import QuantizationDataType, QuantScheme
from aimet_common.utils import AimetLogger
+from aimet_torch.v2.nn.modules.custom import QuantizedConcat
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.nn import BaseQuantizationMixin
@@ -429,6 +432,23 @@ def _propagate_request_upstream_helper(module):
_propagate_request_upstream_helper(module)
return mp_requests

+    def _get_child_module_and_idx(self, module: torch.nn.Module):
+        """
+        Helper to get the child modules and their input indices, consistent with QuantSim's interpretation
+
+        :param module: module whose child modules and input indices should be returned
+        """
+        child_module_idxs = self.cg_traverser.get_child_module_at_output(module)
+        # Even if a concat op has more than one input, QuantSim adds only one input quantizer.
+        # Always return idx=0 for such modules.
+        updated_child_module_idxs = []
+        for child_module, input_idx in child_module_idxs:
+            if isinstance(child_module, QuantizedConcat):
+                input_idx = 0
+            updated_child_module_idxs.append((child_module, input_idx))
+        return updated_child_module_idxs
+
+
def _resolve_request_outputs(self, mp_requests, log_file: IO):
"""
Determine whether output candidates from the request at the provided module should be applied or discarded
@@ -442,7 +462,7 @@ def _resolve_request_outputs_helper(module):
return

# If the output request at this module came from a downstream consumer, return without changing candidate
-child_modules_and_idxs = self.cg_traverser.get_child_module_at_output(module)
+child_modules_and_idxs = self._get_child_module_and_idx(module)
for child_module, input_idx in child_modules_and_idxs:
child_request = mp_requests.get(child_module)
if child_request and child_request.input_candidates and \
@@ -516,9 +536,9 @@ def _apply_new_request_for_module(module, request) -> bool:
# module does not have a request. Create a new one based on the request input
self._update_request_at_module(mp_requests,
module,
-request.input_candidates[0] if request.output_candidates else None,
+request.input_candidates[0] if request.output_candidates and len(request.output_candidates) > 0 else None,
copy.deepcopy(request.param_candidate) if len(module.param_quantizers.keys()) else None,
-request.output_candidates[0] if request.output_candidates else None,
+request.output_candidates[0] if request.output_candidates and len(request.output_candidates) > 0 else None,
strict=strict)
mp_requests[module].id = request.id

@@ -556,7 +576,7 @@ def _apply_new_request_for_module(module, request) -> bool:

# if parent has output quantizer, propagate this request to all other children
if any(parent_module.output_quantizers):
-child_modules_and_idxs = self.cg_traverser.get_child_module_at_output(parent_module)
+child_modules_and_idxs = self._get_child_module_and_idx(parent_module)
for child_module, _ in child_modules_and_idxs:
new_request = MpRequest(id=mp_request.id,
input_candidates=[input_candidate] * len(child_module.input_quantizers),
@@ -576,7 +596,7 @@ def _apply_new_request_for_module(module, request) -> bool:
if mp_request.output_candidates:
# resolve at output using the output candidate; if the module already has an output quantizer, there is no need to resolve at output
if not any(current_module.output_quantizers):
-child_modules_and_idxs = self.cg_traverser.get_child_module_at_output(current_module)
+child_modules_and_idxs = self._get_child_module_and_idx(current_module)
for child_module, _ in child_modules_and_idxs:
new_request = MpRequest(id=mp_request.id,
input_candidates=mp_request.output_candidates * len(child_module.input_quantizers),
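The second file in this PR (below) adds a regression test. Before it, the remapping performed by the new _get_child_module_and_idx helper can be shown standalone; this is a self-contained sketch with stand-in types and a hypothetical function name, not AIMET code:

from typing import List, Tuple

class QuantizedConcat:
    """Stand-in for aimet_torch.v2.nn.modules.custom.QuantizedConcat."""

def normalize_child_idxs(pairs: List[Tuple[object, int]]) -> List[Tuple[object, int]]:
    # Mirrors the helper: Concat children are always remapped to input idx 0,
    # since QuantSim holds only one input quantizer for them.
    return [(child, 0 if isinstance(child, QuantizedConcat) else idx)
            for child, idx in pairs]

# The concat's second input (idx=1) maps to 0; other children keep their index.
pairs = [(QuantizedConcat(), 1), (object(), 1)]
assert [idx for _, idx in normalize_child_idxs(pairs)] == [0, 1]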
@@ -44,6 +44,7 @@
from torch import nn

from aimet_common.defs import QuantizationDataType
+from aimet_torch.v2.nn import BaseQuantizationMixin
from aimet_torch.v2.quantization.base.quantizer import QuantizerBase
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.mixed_precision import MixedPrecisionConfigurator, SupportedDType, Precision
@@ -1428,3 +1429,43 @@ def test_mp_43(self, test_pass_scenario):
with pytest.raises(RuntimeError):
with open(os.path.join(tmp_dir, './mmp_log.txt'), 'w') as f:
mp_configurator.apply(f, strict=True)


+    def test_mp_44(self):
+        """
+        For a concat op, there is always a single input quantizer added irrespective of the number of inputs. This test
+        validates that this scenario is handled correctly
+        """
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(10, 10)
+                self.fc2 = nn.Linear(10, 10)
+                self.concat = aimet_elementwise.Concat()
+                self.fc3 = nn.Linear(10, 10)
+
+            def forward(self, *inputs):
+                x1 = self.fc1(inputs[0])
+                x2 = self.fc2(inputs[0])
+                x = self.concat(x1, x2)
+                return self.fc3(x)
+
+        model = TestModel()
+        input_shape = (5, 10)
+
+        torch.manual_seed(0)
+        input_tensor = torch.randn(*input_shape)
+
+        sim = QuantizationSimModel(model, input_tensor)
+        mp_configurator = MixedPrecisionConfigurator(sim)
+        mp_configurator.set_precision(sim.model.fc1, activation='int16', param={'weight': 'int16'})
+        mp_configurator.set_precision(sim.model.fc2, activation='int16', param={'weight': 'int16'})
+        mp_configurator.set_precision(sim.model.concat, activation='int16')
+        mp_configurator.set_precision(sim.model.fc3, activation='int16', param={'weight': 'int16'})
+        mp_configurator.apply()
+
+        for module in sim.model.modules():
+            if isinstance(module, BaseQuantizationMixin):
+                for q in [*module.input_quantizers, *module.output_quantizers, *module.param_quantizers.values()]:
+                    if q:
+                        assert q.bitwidth == 16
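For intuition, the failure mode the idx=0 remap prevents can be reproduced with plain lists. This is a contrived sketch, not AIMET code: the graph traverser reports the concat's second operand at position 1, while QuantSim holds a single input quantizer:

input_quantizers = ["q0"]  # QuantSim adds exactly one input quantizer for Concat
graph_input_idx = 1        # graph-level position of the concat's second operand

try:
    input_quantizers[graph_input_idx]  # lookup by raw graph position overflows
except IndexError:
    # _get_child_module_and_idx remaps Concat children to idx=0 up front,
    # so the fixed code path performs this lookup instead:
    assert input_quantizers[0] == "q0"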