This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Combine amax reduction calls #163

Closed · 1 commit

67 changes: 53 additions & 14 deletions float8_experimental/float8_linear_utils.py
@@ -87,8 +87,20 @@ def swap_linear_with_float8_linear(
swap_linear_with_float8_linear(child, module, emulate)


def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
    if fp8_classes is None:
        fp8_classes = Float8Linear

    # Get all fp8 layers and tensors
    fp8_layers = [
        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
    ]

    return fp8_layers


def sync_float8_amax_and_scale_history(
    model: torch.nn.Module, fp8_classes=None
Contributor:
we should probs document this function and also what the new args do / how they should be used, since I assume you get all the fp8_layers once and then pass that in every iteration

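A minimal sketch of the usage pattern the comment above describes, based on the new signature in this diff; `model`, `optimizer`, and `data_loader` below are placeholders, not names from the repo:

from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

# Traverse the module tree once, outside the training loop.
fp8_layers = get_float8_layers(model)

for batch in data_loader:
    loss = model(batch).sum()
    loss.backward()
    # Reusing fp8_layers avoids a named_modules() walk every iteration;
    # fp8_classes is ignored when fp8_layers is passed.
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()
    optimizer.zero_grad()
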
    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
) -> None:
"""
Manages the float8 amax and scale bookkeeping. In detail, it does the
@@ -103,27 +115,54 @@ def sync_float8_amax_and_scale_history(

    Args:
        model (torch.nn.Module): The model to track amaxes for
        fp8_classes (optional): The fp8 classes to look for in the model.
            The default is Float8Linear.
            When using with TP, users can pass in the customized TP classes instead.
        fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
            and we loop over all fp8_layers to sync and update amax scale histories.
            Users can use get_float8_layers to get all fp8 layers.
    """

    # For now, this is written in a naive way to maximize code readability.
    # TODO(future): benchmark and optimize as needed, we can combine all
    # the reductions into one and probably make the history update faster.
    # Lazy import to avoid circular dependency

    if fp8_classes is None:
        fp8_classes = Float8Linear

    for name, child in model.named_modules():
        if not isinstance(child, fp8_classes):
            continue
    # TODO(future): benchmark and optimize as needed, we have combined all
    # the reductions into one and we can probably try other optimizations to
    # make the history update faster.

    if fp8_layers is None:
        fp8_layers = get_float8_layers(model, fp8_classes)

    if dist.is_initialized():
        fp8_amax_x_tensor = torch.tensor(
            [child.fp8_amax_x for child in fp8_layers],
            dtype=torch.float32,
            device="cuda",
            requires_grad=False,
        )
        fp8_amax_w_tensor = torch.tensor(
            [child.fp8_amax_w for child in fp8_layers],
            dtype=torch.float32,
            device="cuda",
            requires_grad=False,
        )
        fp8_amax_dL_dY_tensor = torch.tensor(
            [child.fp8_amax_dL_dY for child in fp8_layers],
            dtype=torch.float32,
            device="cuda",
            requires_grad=False,
        )
        dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
        dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
        dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)

    for idx in range(len(fp8_layers)):
        child = fp8_layers[idx]
        #
        # 1. in distributed contexts, syncs amax values across workers
        #
        if dist.is_initialized():
            dist.all_reduce(child.fp8_amax_x, op=dist.ReduceOp.MAX)
            dist.all_reduce(child.fp8_amax_w, op=dist.ReduceOp.MAX)
            dist.all_reduce(child.fp8_amax_dL_dY, op=dist.ReduceOp.MAX)
            child.fp8_amax_x = fp8_amax_x_tensor[idx]
            child.fp8_amax_w = fp8_amax_w_tensor[idx]
            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx]

        #
        # 2. adds the `amax` values to history
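To summarize the change in isolation: pack every layer's scalar amax into one tensor, issue a single all_reduce per buffer name instead of one per layer, and scatter the reduced maxima back. The sketch below is illustrative only, not the PR's exact code; the helper name `reduce_amaxes_once` and the `amax_attrs` tuple are made up for the example, and it assumes torch.distributed is already initialized with the buffers on the collective's device.

import torch
import torch.distributed as dist

def reduce_amaxes_once(fp8_layers, amax_attrs=("fp8_amax_x", "fp8_amax_w", "fp8_amax_dL_dY")):
    for attr in amax_attrs:
        # One packed tensor per buffer name -> one collective per name,
        # rather than one collective per layer per name.
        packed = torch.stack([getattr(layer, attr).detach().float() for layer in fp8_layers])
        dist.all_reduce(packed, op=dist.ReduceOp.MAX)
        for idx, layer in enumerate(fp8_layers):
            # Write the globally reduced max back into each layer's buffer.
            getattr(layer, attr).copy_(packed[idx])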