This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Combine amax reduction calls #163

Closed
wants to merge 1 commit into from

Conversation

Contributor

@y-sq commented Dec 15, 2023

Combine the amax sync reductions (combine-reduction is now the default behavior, not an option).

  • Combine the reduction call for each amax tensor type (3 all_reduce calls in total). These could be further combined into one single call; see the sketch after this comment.
    • Verified that existing tests still pass, so existing benchmark code does not need to change:
      • pytest test/test_base.py
      • ./test/test_fsdp.sh
  • Tested the new option using small llama models with 8 FSDP groups. Time spent in sync_float8_amax_and_scale_history dropped from 29ms[1] to 10ms[2].

[1] Traces without combine reduction, https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.138932292910521.json.gz&bucket=acadia
[2] https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.202842416426594.json.gz&bucket=acadia
* Results from trace[2] were updated to the correct number.
** Need Meta internal access to open these traces.
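
For illustration, a minimal sketch of the per-type combined reduction described above, assuming each float8 layer exposes fp8_amax_x, fp8_amax_w, and fp8_amax_dL_dY buffers (the helper name and the copy-back loop are illustrative, not the exact code in this PR):

import torch
import torch.distributed as dist

def combined_amax_all_reduce(fp8_layers):
    # One all_reduce per amax tensor type instead of one per layer.
    for attr in ("fp8_amax_x", "fp8_amax_w", "fp8_amax_dL_dY"):
        # Stack the per-layer amax buffers into a single flat tensor.
        stacked = torch.cat([getattr(layer, attr).reshape(1) for layer in fp8_layers])
        dist.all_reduce(stacked, op=dist.ReduceOp.MAX)
        # Write the globally reduced amax values back into each layer's buffer.
        for i, layer in enumerate(fp8_layers):
            getattr(layer, attr).copy_(stacked[i])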

@facebook-github-bot added the CLA Signed label on Dec 15, 2023
@y-sq changed the base branch from bench-multi-gpu to main on December 18, 2023 at 22:58
@y-sq force-pushed the combine-reduction branch from 4c6badd to 174b08a on December 18, 2023 at 22:58
@y-sq marked this pull request as ready for review on December 18, 2023 at 22:59
@facebook-github-bot
Contributor

@y-sq has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@y-sq changed the title from conbine-reduction to Combine amax reduction calls on Dec 18, 2023
fp8_layers = get_float8_layers(model, fp8_classes)

if dist.is_initialized():
# TODO: Testing if combine_reduction improves performance.
Contributor

Seems like it does, right? So we can remove this TODO. Should we make this the default behavior?

device="cuda",
requires_grad=False,
)
# print("fp8_amax_x_tensor, ", fp8_amax_x_tensor)
Contributor

nit: probs remove right?

def sync_float8_amax_and_scale_history(
model: torch.nn.Module, fp8_classes=None
Contributor

we should probs document this function and also what the new args do / how they should be used, since I assume you get all the fp8_layers once and then pass that in every iteration
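
For illustration, one possible shape of that documentation, assuming the fp8_layers argument added in this PR (the docstring wording below is a sketch, not the text that landed):

import torch

def sync_float8_amax_and_scale_history(
    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
):
    """Sync the amax histories and scaling factors of all float8 layers in the model.

    Args:
        model: the model whose float8 layers should be synced.
        fp8_classes: optional override of the float8 layer classes to search for.
        fp8_layers: optional pre-computed list of float8 layers. If provided, the
            module traversal is skipped; callers can collect the layers once with
            get_float8_layers(model, fp8_classes) and reuse the list every iteration.
    """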

Contributor

@drisspg left a comment

Some small comments and I think you need to run ufmt format . but otherwise awesome speed ups!!

)
# print("fp8_amax_x_tensor, ", fp8_amax_x_tensor)
dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
Contributor

fp8_amax_w_tensor?

Contributor

ohh damn good catch

# print("fp8_amax_x_tensor, ", fp8_amax_x_tensor)
dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
Contributor

dL_dY?

Contributor Author

Thanks!!!
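
For clarity, after the fix pushed in response to these comments, the three reductions presumably look like this (the wrapper function below is illustrative; the tensor names come from the review thread):

import torch.distributed as dist

def _all_reduce_amax_tensors(fp8_amax_x_tensor, fp8_amax_w_tensor, fp8_amax_dL_dY_tensor):
    # One max-reduction per amax tensor type: activations, weights, and output grads.
    dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
    dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
    dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)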

def sync_float8_amax_and_scale_history(
model: torch.nn.Module, fp8_classes=None
model: torch.nn.Module, fp8_classes=None, fp8_layers=None, combine_reduction=False
Contributor

can we measure time on single GPU, and if that's a net positive as well just delete the old code path? It would be great to keep things simple.

Contributor Author

@y-sq Dec 18, 2023

Did you mean the combine_reduction option or the fp8_layers option?

For combine_reduction, I think we can remove it and keep combine_reduction=True as default.

For fp8_layers, we have many existing tests and benchmarks (such as the single-GPU llama_7B benchmarks) that use the original call, sync_float8_amax_and_scale_history(model). So I kept it optional, and None is supported.
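
For illustration, a usage sketch of the two call patterns described above (model, dataloader, and optimizer are placeholders; the two function names come from the PR):

# Existing call path (single-GPU benchmarks): layers are rediscovered on every call.
sync_float8_amax_and_scale_history(model)

# FSDP call path: collect the float8 layers once, then reuse them each iteration.
fp8_layers = get_float8_layers(model, fp8_classes=None)
for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()
    optimizer.zero_grad()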

@y-sq force-pushed the combine-reduction branch 4 times, most recently from 160a899 to 435df2d, on December 19, 2023 at 00:21
Contributor Author

y-sq commented Dec 19, 2023

Updates:

  • Add comments and remove the unused debug comments
  • Remove the "combine_reduction" option; always combine the reductions instead
  • Fix the all_reduce calls to use the correct tensors, fp8_amax_w_tensor and fp8_amax_dL_dY_tensor

@facebook-github-bot
Contributor

@y-sq has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Contributor Author

y-sq commented Dec 19, 2023

I ran the format check on my devgpu server, which didn't give any errors:

$ ufmt check .
✨ 22 files already formatted ✨

However, the check on GitHub still failed.

Summary:
~~Add an option to combine the amax sync reduction~~ (Use combine-reduction as the default behavior)
- Combine the reduction call for each amax tensor type (3 all_reduce calls in total). These could be further combined into one single call; see the sketch after this summary.
  - Verified that existing tests still pass, so existing benchmark code does not need to change:
    - pytest test/test_base.py
    - ./test/test_fsdp.sh
- Tested the new option using small llama models with 8 FSDP groups. Time spent in sync_float8_amax_and_scale_history dropped from 29ms[1] to 3ms[2].

[1] Traces without combine reduction, https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.138932292910521.json.gz&bucket=acadia
[2] https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.202842416426594.json.gz&bucket=acadia
\* Trace[2] was updated after addressing the comments.
\*\* Need Meta internal access to open these traces.


Reviewed By: drisspg

Differential Revision: D52271595

Pulled By: y-sq
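
The summary above mentions further combining the three per-type reductions into one single call; a rough sketch of that idea (not what this PR landed; the function name is illustrative):

import torch
import torch.distributed as dist

def single_all_reduce_amax(fp8_amax_x_tensor, fp8_amax_w_tensor, fp8_amax_dL_dY_tensor):
    # Concatenate the three per-type amax tensors so only one all_reduce is issued.
    tensors = (fp8_amax_x_tensor, fp8_amax_w_tensor, fp8_amax_dL_dY_tensor)
    sizes = [t.numel() for t in tensors]
    combined = torch.cat(tensors)
    dist.all_reduce(combined, op=dist.ReduceOp.MAX)
    # Split the reduced result and copy it back into the original tensors.
    for t, reduced in zip(tensors, torch.split(combined, sizes)):
        t.copy_(reduced)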
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D52271595

@facebook-github-bot
Contributor

@y-sq merged this pull request in b099049.

Labels: CLA Signed, fb-exported, Merged