
Qualcomm AI Engine Direct - Add submodule quant config setting #9355


Merged
merged 4 commits into pytorch:main from add_submodule_quant on Apr 10, 2025

Conversation

chunit-quic
Collaborator

  • Add API to qnn quantizer for setting submodule quant config
  • Refine QnnQuantizer setting functions


pytorch-bot bot commented Mar 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9355

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 137228a with merge base 6adff9c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 18, 2025
@chunit-quic
Collaborator Author

@pytorchbot label "release notes: qualcomm"

@jackzhxng jackzhxng removed their request for review March 18, 2025 17:38
@swolchok swolchok removed their request for review March 18, 2025 21:04
@cccclai
Contributor

cccclai commented Mar 24, 2025

@sxu @billmguo can you help review this PR and see if it meets the need?

@cccclai cccclai requested review from sxu and billmguo March 24, 2025 19:36

nn_module_stack = node.meta.get("nn_module_stack")
if nn_module_stack:
    module_source_str, module_str = list(nn_module_stack.values())[-1][
        -1
    ].rsplit(".", 1)
Contributor

Indexing by -1 here: is the order of the dict guaranteed?

Collaborator Author

Hi @sxu,

Thank you for reviewing!
Yes, based on the description of nn_module_stack on pytorch.org, the order is determined by the stack trace. The same -1 indexing logic is also applied elsewhere in the repo.
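
For illustration, here is a minimal sketch of what nn_module_stack typically contains and why the [-1] indexing picks the innermost module; the exact keys and path strings below are hypothetical:

# Hypothetical contents of node.meta["nn_module_stack"]: an insertion-
# ordered dict mapping unique module keys to (path, type-string) tuples,
# ordered from the outermost to the innermost module on the stack trace.
nn_module_stack = {
    "L__self__": ("", "my_package.MyModel"),
    "L__self___head": ("head", "torch.nn.modules.linear.Linear"),
}

# [-1] selects the innermost (most specific) module; its last tuple element
# is the fully qualified type string.
module_source_str, module_str = list(nn_module_stack.values())[-1][
    -1
].rsplit(".", 1)
assert (module_source_str, module_str) == ("torch.nn.modules.linear", "Linear")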

module_source_str, module_str = list(nn_module_stack.values())[-1][
    -1
].rsplit(".", 1)
module_source = importlib.import_module(module_source_str)
Contributor

The module might not be available for import, for example when the output of torch.export.export is produced ahead of time. If I'm not mistaken, loading the exported program doesn't require the source modules to be available.
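
To illustrate the scenario, a hedged sketch (the file name and module names are hypothetical) of an ahead-of-time export whose source package is absent at load time:

import torch

# Authoring environment (the model's source code is importable):
#   ep = torch.export.export(MyModel(), (example_input,))
#   torch.export.save(ep, "model.pt2")

# Deployment environment (the source package may not be installed):
ep = torch.export.load("model.pt2")  # loads without importing the source module
# nn_module_stack metadata still carries type strings such as
# "my_package.MyModel", but importlib.import_module("my_package")
# would raise ModuleNotFoundError here.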

Collaborator Author

Could you please provide a failing example? To my knowledge, you must have a Torch neural network to invoke torch.export.export. To have a Torch neural network, you must import the corresponding modules. Therefore, it seems to me that we are guaranteed to have the module available for import here.

Collaborator Author

Quick update. Thanks to your comments, we found that in some cases certain modules failed to import. Switching to a string-based comparison bypasses this issue.
One more thing: could you please share more thoughts on this comment? Thank you :)
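
A minimal sketch of the string-based comparison (the helper name is an assumption, not the actual PR code): matching the type string recorded in nn_module_stack directly removes the need to import the module at all.

def module_type_matches(node, target_type_str):
    # Compare fully qualified type strings from nn_module_stack instead of
    # importing the module and comparing classes.
    nn_module_stack = node.meta.get("nn_module_stack")
    if not nn_module_stack:
        return False
    return any(
        type_str == target_type_str
        for _, type_str in nn_module_stack.values()
    )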

self.per_channel_quant_config = get_ptq_per_channel_quant_config()
self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
self.default_quant_config = ModuleQConfig()
self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
Contributor

@sxu sxu Mar 26, 2025

This works for setting different qconfigs by submodule type, but another important use case is by submodule name. For example, both self.backbone and self.head contain linear layers, but we want 8-bit activations for the backbone and 16-bit for the head.

How about making this a List[Tuple[Callable[[fx.Node], bool], ModuleQConfig]]? The first element of each tuple is a predicate determining whether the qconfig should be applied to a given node. Each predicate in the list is evaluated sequentially, and earlier ones take priority over later ones. Then there can be some utility for creating common predicates, for example:

def get_submodule_type_predicate(module_type):
    def predicate(node):
        if nn_module_stack := node.meta.get("nn_module_stack"):
            return module_type in (x[1] for x in nn_module_stack.values())
        return False

    return predicate

or

def get_submodule_name_predicate(module_name):
    def predicate(node):
        if nn_module_stack := node.meta.get("nn_module_stack"):
            return module_name in nn_module_stack  # keys are module paths
        return False

    return predicate

which can be used like this:

predicate_qconfigs = [
    (get_submodule_name_predicate("self.head"), my_16a8w_qconfig),
    (get_submodule_type_predicate("torch.nn.Linear"), my_8a8w_qconfig),
    (lambda n: True, default_qconfig),
]
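
A minimal sketch of how a quantizer could consume such a list (the function name is an assumption, not the actual QnnQuantizer API): predicates are evaluated in order and the first match wins.

def qconfig_for_node(node, predicate_qconfigs):
    # Earlier entries take priority; the trailing (lambda n: True,
    # default_qconfig) entry acts as a catch-all fallback.
    for predicate, qconfig in predicate_qconfigs:
        if predicate(node):
            return qconfig
    return None  # no entry matched; leave the node unannotated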

Collaborator Author

Thank you very much for providing the use cases and pseudo code!
They clarify some of the requirements we were unsure about. I have a few questions to make sure we fully understand the requirements:

  1. So, basically, all submodules matching the given name or submodule type should share the same qconfig, correct?
  2. Is there any data structure more robust than a string? We find it challenging to find one after export. As in your example, it seems we can only rely on the strings from nn_module_stack, which led us to use importlib for additional assurance.
  3. Similar to point 2, the name self.head is defined by users in nn.Module. Therefore, an arbitrary name might appear in any module, which could cause us to accidentally set a qconfig on an unintended target. Asking for more information (like the module type the user wants) can work around this, yet I guess we might need a more reliable data structure to search for them in the future.

Collaborator Author

Just a quick update. We changed the mapping approach based on your comments. Please feel free to let us know of any problems. Thank you. :D

@chunit-quic chunit-quic force-pushed the add_submodule_quant branch from 1c8f72e to 89673b6 Compare April 2, 2025 02:46
Contributor

@sxu sxu left a comment

Sorry I missed this for the last couple of days, but thanks for the follow-up! LGTM.

Chun-I Tsai added 3 commits April 7, 2025 09:20
- Add API to qnn quantizer for setting submodule quant config
- Change to string-based way to set up qconfig for submodule
@chunit-quic chunit-quic force-pushed the add_submodule_quant branch from 89673b6 to f2ce0e7 Compare April 7, 2025 01:23
@cccclai cccclai added the release notes: qualcomm label Apr 7, 2025
@cccclai
Contributor

cccclai commented Apr 7, 2025

There is a lint error, can you fix it?

@chunit-quic
Collaborator Author

There is a lint error, can you fix it?

Thanks for pointing that out! Fixed.

@chunit-quic
Collaborator Author

Hi @cccclai, if this PR looks good to you, could you please help merge it? Thank you. :)

@cccclai
Contributor

cccclai commented Apr 10, 2025

Yes, sorry for being late, merging.

@cccclai cccclai merged commit 1d43b3b into pytorch:main Apr 10, 2025
88 checks passed
kirklandsign pushed a commit that referenced this pull request Apr 11, 2025
- Add API to qnn quantizer for setting submodule quant config
- Refine QnnQuantizer setting functions

---------

Co-authored-by: Chun-I Tsai <chunit@qti.qualcomm.com>
keyprocedure pushed a commit to keyprocedure/executorch that referenced this pull request Apr 21, 2025
…ch#9355)

- Add API to qnn quantizer for setting submodule quant config
- Refine QnnQuantizer setting functions

---------

Co-authored-by: Chun-I Tsai <chunit@qti.qualcomm.com>