This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b099049

y-sq authored and facebook-github-bot committed
Combine amax reduction calls (#163)
Summary:

~~Add an option to combine the amax sync reduction~~ (Use combined reduction as the default behavior.)

- Combine the reduction calls for each type of amax scaling factor (3 all_reduce calls in total). We could further combine them into one single call.
- Verified that other tests still pass, so the existing benchmark code does not need to change:
  - pytest test/test_base.py
  - ./test/test_fsdp.sh
- Tested the new option using small llama models with 8 FSDP groups. Time taken by sync_float8_amax_and_scale_history reduced from 29ms [1] to 3ms [2].

[1] Trace without combined reduction: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.138932292910521.json.gz&bucket=acadia
[2] https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.202842416426594.json.gz&bucket=acadia

\* Trace [2] was updated after addressing the review comments.
\*\* Meta internal access is needed to open these traces.

Pull Request resolved: #163

Reviewed By: drisspg

Differential Revision: D52271595

Pulled By: y-sq

fbshipit-source-id: 65d27d32cb4d291dc6fbe62b7a916cf2e32e6482
1 parent: c40de9b
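
The commit replaces the per-layer all_reduce calls (three per Float8Linear layer) with three model-wide collectives, one per amax buffer type (fp8_amax_x, fp8_amax_w, fp8_amax_dL_dY). The sketch below illustrates the packing pattern for a single buffer type; it is a simplified illustration rather than the commit's code, and `sync_amax_combined` is a hypothetical helper name.

```python
import torch
import torch.distributed as dist


def sync_amax_combined(amax_buffers: list[torch.Tensor]) -> None:
    """Reduce many scalar amax buffers with a single all_reduce(MAX)."""
    if not dist.is_initialized() or len(amax_buffers) == 0:
        return
    # Pack all per-layer amax scalars into one flat tensor on the GPU.
    packed = torch.tensor(
        [buf.item() for buf in amax_buffers],
        dtype=torch.float32,
        device="cuda",
        requires_grad=False,
    )
    # One collective call instead of one per layer.
    dist.all_reduce(packed, op=dist.ReduceOp.MAX)
    # Write the globally synced maxima back into each layer's buffer.
    for buf, synced in zip(amax_buffers, packed):
        buf.copy_(synced)
```

With N fp8 layers this turns 3·N small collectives into 3 (or, if the three packed tensors were concatenated, a single call), which is consistent with the 29ms → 3ms improvement reported in the summary.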


float8_experimental/float8_linear_utils.py

Lines changed: 53 additions & 14 deletions
```diff
@@ -87,8 +87,20 @@ def swap_linear_with_float8_linear(
         swap_linear_with_float8_linear(child, module, emulate)
 
 
+def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
+    if fp8_classes is None:
+        fp8_classes = Float8Linear
+
+    # Get all fp8 layers and tensors
+    fp8_layers = [
+        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
+    ]
+
+    return fp8_layers
+
+
 def sync_float8_amax_and_scale_history(
-    model: torch.nn.Module, fp8_classes=None
+    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
 ) -> None:
     """
     Manages the float8 amax and scale bookkeeping. In detail, it does the
@@ -103,27 +115,54 @@ def sync_float8_amax_and_scale_history(
 
     Args:
         model (torch.nn.Module): The model to track amaxes for
+        fp8_classes (optional): The fp8 classes to look for in the model.
+            The default is Float8Linear.
+            When using with TP, users can pass in the customized TP classes instead.
+        fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
+            and we loop over all fp8_layers to sync and update amax scale histories.
+            Users can use get_float8_layers to get all fp8 layers.
     """
 
     # For now, this is written in a naive way to maximize code readability.
-    # TODO(future): benchmark and optimize as needed, we can combine all
-    # the reductions into one and probably make the history update faster.
-    # Lazy import to avoid circular dependency
-
-    if fp8_classes is None:
-        fp8_classes = Float8Linear
-
-    for name, child in model.named_modules():
-        if not isinstance(child, fp8_classes):
-            continue
+    # TODO(future): benchmark and optimize as needed, we have combined all
+    # the reductions into one and we can probably try other optimizatons to
+    # make the history update faster.
+
+    if fp8_layers is None:
+        fp8_layers = get_float8_layers(model, fp8_classes)
+
+    if dist.is_initialized():
+        fp8_amax_x_tensor = torch.tensor(
+            [child.fp8_amax_x for child in fp8_layers],
+            dtype=torch.float32,
+            device="cuda",
+            requires_grad=False,
+        )
+        fp8_amax_w_tensor = torch.tensor(
+            [child.fp8_amax_w for child in fp8_layers],
+            dtype=torch.float32,
+            device="cuda",
+            requires_grad=False,
+        )
+        fp8_amax_dL_dY_tensor = torch.tensor(
+            [child.fp8_amax_dL_dY for child in fp8_layers],
+            dtype=torch.float32,
+            device="cuda",
+            requires_grad=False,
+        )
+        dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
+        dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
+        dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)
 
+    for idx in range(len(fp8_layers)):
+        child = fp8_layers[idx]
         #
         # 1. in distributed contexts, syncs amax values across workers
         #
         if dist.is_initialized():
-            dist.all_reduce(child.fp8_amax_x, op=dist.ReduceOp.MAX)
-            dist.all_reduce(child.fp8_amax_w, op=dist.ReduceOp.MAX)
-            dist.all_reduce(child.fp8_amax_dL_dY, op=dist.ReduceOp.MAX)
+            child.fp8_amax_x = fp8_amax_x_tensor[idx]
+            child.fp8_amax_w = fp8_amax_w_tensor[idx]
+            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx]
 
         #
         # 2. adds the `amax` values to history
```
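
Based on the new get_float8_layers helper and fp8_layers argument shown in the diff, a caller can collect the fp8 layers once and reuse the list every iteration instead of re-walking the module tree per step. The following is a hedged usage sketch; the training loop, data_loader, and build_fp8_model helper are placeholders rather than code from this repository, and the import path is assumed.

```python
from float8_experimental.float8_linear_utils import (  # assumed import path
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

# Assume `model` already has its nn.Linear modules swapped to Float8Linear
# (e.g. via swap_linear_with_float8_linear).
model = build_fp8_model()  # hypothetical helper

# Collect the fp8 layers once, instead of scanning named_modules() every step.
fp8_layers = get_float8_layers(model)

for batch in data_loader:  # hypothetical iterable of input tensors
    loss = model(batch).sum()
    loss.backward()
    # Passing fp8_layers skips the per-step module-tree scan; the combined
    # all_reduce path runs automatically when torch.distributed is initialized.
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
```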
