[float8] add float8 rowwise MoE prototype #1245
Conversation
Thanks! The UI makes sense to me.
torchtitan/config_manager.py (Outdated)

@@ -465,6 +465,12 @@ class Float8:
    Not compatible with torch.compile.
    """

    moe_fqns: list[str] | str = field(default_factory=list)
can we add "prototype" to the field name and add a link to the README in the docstring
@@ -465,6 +465,13 @@ class Float8:
    Not compatible with torch.compile.
    """

    moe_fqns_prototype: list[str] | str = field(default_factory=list)
no need to add "prototype" to config name?
Suggested change:
- moe_fqns_prototype: list[str] | str = field(default_factory=list)
+ moe_fqns: list[str] | str = field(default_factory=list)
@vkuzo requested "prototype" be in the field name here. Unless I misunderstood the suggestion?
Alternatively we could omit "prototype" from the field name and just make sure the docstring/help text is very clear it is a prototype feature with limitations.
For context, I don't plan to land this until at least FSDP is supported (ideally TP as well).
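For illustration, a rough sketch of that alternative (no "prototype" suffix in the field name, with the prototype caveat carried by the help text instead). The docstring wording and README reference here are placeholders, not the text that actually landed in the PR:

```python
# Sketch only: illustrates the "clear docstring instead of a 'prototype' suffix" option.
# The docstring wording and README reference are placeholders, not the actual PR text.
from dataclasses import dataclass, field


@dataclass
class Float8:
    moe_fqns: list[str] | str = field(default_factory=list)
    """
    Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise
    training to. Prototype feature with limitations: currently single-GPU only, and it
    relies on a torchao prototype. See the torchao MoE training README for details.
    """
```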
I'm OK either way then. Also, since this lives in an experiments folder, everything could be considered experimental.
@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
moe_fqns = []
let's put something in the list
Added "experts"
as the default value (this is what I've been testing with).
thanks, had two more comments
@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
moe_fqns = ["experts"]
Do you want to capture the shared expert? If so, you may need to use "expert" instead of "experts":
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/model/moe.py#L204
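To make the substring point concrete, a tiny illustration (assuming FQN filtering is a plain substring check; the FQNs below are hypothetical examples for the routed experts and shared expert modules):

```python
# Illustrative only: assumes FQN matching is a plain substring check; the FQNs below are
# hypothetical examples for the routed experts and shared expert modules.
moe_fqns = ["experts"]

for fqn in ["layers.0.moe.experts", "layers.0.moe.shared_expert"]:
    print(fqn, "->", any(target in fqn for target in moe_fqns))
# layers.0.moe.experts -> True
# layers.0.moe.shared_expert -> False  ("expert" instead of "experts" would match both)
```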
If this is well-tested, let's put it into the other toml configs as well.
Not yet, this is intentional: the routed experts work with FSDP and TP, but the shared expert only works with FSDP right now. Still debugging an issue related to shared expert + TP.
# Summary
- Adds `--float8.moe_fqns_prototype="..."` option to the float8 training API.
- The API accepts a comma-separated list of FQNs to apply MoE float8 training conversion to.
- `quantize_` with the `MoETrainingConfig` will recursively swap nn.Parameter data tensors to a tensor subclass, which has an override for grouped_mm => [dynamic quant + scaled grouped mm](https://github.com/pytorch/ao/blob/d963a8840e3c228e303fe14aff5d9be7017c92b6/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py#L20) prototype. Context: see the implementation of GroupedExperts [here](https://github.com/pytorch/torchtitan/blob/ca10545e41582fed4ebb00db4c13db71194a0dfa/torchtitan/experiments/llama4/model/moe.py#L85-L87).

# Testing
- Tested manually with the torchao `convert_moe_to_float8_training` prototype ([PR](pytorch/ao#2275)) and confirmed single GPU training works as expected.

# Limitations
- Only supports single GPU training so far.
- Only performs the grouped_mm override for routed experts (see condition [here](https://github.com/pytorch/ao/pull/2275/files#diff-c529b94621368096076db6bec8a6fc058d7f7595c39cd59965c657ed5dea861cR29-R33)). For shared experts, I'll need to update the torchao prototype to support a 3d A tensor (see torchtitan [here](https://github.com/pytorch/torchtitan/blob/ca10545e41582fed4ebb00db4c13db71194a0dfa/torchtitan/experiments/llama4/model/moe.py#L316)).
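For readers unfamiliar with the torchao side, here is a hedged sketch of the conversion flow described above, assuming the FQN filter is a simple substring match. The import path for `MoETrainingConfig` is an assumption based on the linked torchao PR and may differ across torchao versions:

```python
# Hedged sketch of the MoE float8 conversion flow. The MoETrainingConfig import path is an
# assumption based on pytorch/ao#2275 and may differ in other torchao versions.
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig  # assumed path


def convert_moe_to_float8(model: nn.Module, moe_fqns: list[str]) -> None:
    """Swap parameter data of matching MoE modules to the float8 training tensor subclass."""

    def moe_module_filter(module: nn.Module, fqn: str) -> bool:
        # Convert only modules whose FQN contains one of the configured substrings,
        # e.g. moe_fqns = ["experts"] from --float8.moe_fqns_prototype split on commas.
        return any(target in fqn for target in moe_fqns)

    # quantize_ walks the model recursively; for modules passing the filter, parameter data
    # tensors are swapped to the subclass whose grouped_mm override dispatches to the
    # dynamic-quant + scaled grouped mm prototype kernel.
    quantize_(model, MoETrainingConfig(), filter_fn=moe_module_filter)
```

Usage would be something like `convert_moe_to_float8(model, ["experts"])` after model construction.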