This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 8ed0eb7

vkuzo authored and facebook-github-bot committed
remove unused FSDP float8 all-gather integration (#182)
Summary: All of the removed code is outdated; we will try again with per-parameter sharding FSDP.

Pull Request resolved: #182

Test Plan:
```
./test/test_everything.sh
```

Reviewed By: awgu

Differential Revision: D52648709

Pulled By: vkuzo

fbshipit-source-id: 202c2675e4e96a035d5f345346f1a1a3816f02ae
1 parent: 21c7423
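The user-facing conversion path is untouched by this commit; only the FSDP-specific weight tagging and pre-cast handling are gone. A minimal sketch of that path follows (assuming a CUDA device with float8 support; the module path and `from_float` signature are taken from the diffs below, the shapes are illustrative):

```python
import copy

import torch
import torch.nn as nn

from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

# Convert a plain nn.Linear to its dynamically scaled float8 counterpart.
m_ref = nn.Linear(16, 32, bias=False, device="cuda")
m_fp8 = Float8DynamicLinear.from_float(copy.deepcopy(m_ref))

x = torch.randn(4, 16, device="cuda")
y = m_fp8(x)  # input and weight are both cast to float8 inside forward

# from_float no longer tags the weight for FSDP, so the removed attribute is absent.
assert not hasattr(m_fp8.weight, "_is_fp8_weight")
```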

4 files changed, 4 additions and 50 deletions.

float8_experimental/dynamic_linear/dynamic_float8_linear.py

Lines changed: 1 addition & 14 deletions

```diff
@@ -46,16 +46,9 @@ class Float8DynamicLinear(torch.nn.Linear):
     conversion to fp8 of the input and weight tensors.
     """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.add_weight_tag()
-
     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
-            w_fp8 = self._w_fp8
-        else:
-            w_fp8 = self.cast_to_float8(self.weight)
+        w_fp8 = self.cast_to_float8(self.weight)

         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

@@ -64,11 +57,6 @@ def forward(self, x):

         return y

-    def add_weight_tag(self):
-        # We add a tag to the weight nn.Parameter in order to signal
-        # To FSDP that this param is a weight
-        self.weight._is_fp8_weight = True
-
     def cast_to_float8(self, inpt_tensor):
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
@@ -92,5 +80,4 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
-        new_mod.add_weight_tag()
         return new_mod
```
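To make concrete what the deleted branch was for: the `_w_fp8` attribute let an external wrapper (the old FSDP integration) cast the weight ahead of time and have `forward` pick up the result. The sketch below imitates that mechanism with a standard pre-forward hook; it is illustrative only and not the actual FSDP code, which lived outside these modules. After this commit, `forward` always casts `self.weight` and any stashed attribute is simply ignored.

```python
import torch
import torch.nn as nn

from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

m = Float8DynamicLinear.from_float(nn.Linear(16, 32, bias=False, device="cuda"))

def precast_weight(module, args):
    # Roughly what the removed integration did before forward ran.
    module._w_fp8 = module.cast_to_float8(module.weight)

handle = m.register_forward_pre_hook(precast_weight)
y = m(torch.randn(4, 16, device="cuda"))  # post-commit: `_w_fp8` is ignored, weight is re-cast
handle.remove()
```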

float8_experimental/float8_linear.py

Lines changed: 1 addition & 21 deletions

```diff
@@ -171,13 +171,6 @@ def __init__(self, *args, **kwargs):
         # Note: this is not used in non-TP code.
         self.use_sequence_parallel = False

-        # Save the Float8Tensor constructor for FSDP.
-        # N.B. Do not partially apply the scale into the constructor because
-        # buffer Python IDs are not preserved by `nn.Module.to()` and the
-        # module could be moved to GPU after this constructor. Instead, FSDP
-        # will access the scale when it has ensured that it is on GPU.
-        self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)
-
         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
         self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
@@ -312,30 +305,18 @@ def float8_post_forward(self):
         self.is_amax_initialized = True
         self.amax_and_scale_synced = False

-    def add_weight_tag(self):
-        # We add a tag to the weight nn.Parameter in order to signal
-        # To FSDP that this param is a weight
-        self.weight._is_fp8_weight = True
-

 class Float8Linear(Float8LinearMixin, torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
     scales in way friendly to delayed scaling.
     """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.add_weight_tag()
-
     def forward(self, x):
         self.float8_pre_forward(x)

         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
-        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
-            w_fp8 = self._w_fp8
-        else:
-            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

         y = torch.matmul(x_fp8, w_fp8.t())

@@ -366,5 +347,4 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.emulate = emulate
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
-        new_mod.add_weight_tag()
         return new_mod
```
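The deleted `_float8_tensor_ctor` comment warned against partially applying the scale buffer into the constructor because buffer Python IDs are not preserved by `nn.Module.to()`. A small standalone demonstration of that pitfall (illustrative; not code from this repo):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.tensor(1.0))

m = M()
captured = m.scale          # capture the buffer object, as a partial application would
m.to(torch.float64)         # the dtype change replaces the buffer with a new tensor
print(captured is m.scale)  # False: the captured reference is now stale
```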

float8_experimental/tp_linear.py

Lines changed: 2 additions & 10 deletions

```diff
@@ -49,10 +49,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
             input_parallel, self.is_amax_initialized
         )

-        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
-            w_fp8 = self._w_fp8
-        else:
-            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

         # Matrix multiply.
         output_parallel = torch.matmul(input_parallel_fp8, w_fp8.t())
@@ -101,7 +98,6 @@ def from_float(cls, mod, emulate=False):
         device_to_use = next(mod.parameters()).device
         new_mod.to(device_to_use)
         new_mod.emulate = emulate
-        new_mod.add_weight_tag()
         # TODO: test when creation is on cuda
         return new_mod

@@ -137,10 +133,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
             input_parallel, self.is_amax_initialized
         )

-        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
-            w_fp8 = self._w_fp8
-        else:
-            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

         # Matrix multiply.
         output_parallel = torch.matmul(input_parallel_fp8, w_fp8.t())
@@ -196,7 +189,6 @@ def from_float(cls, mod, emulate=False):
         device_to_use = next(mod.parameters()).device
         new_mod.to(device_to_use)
         new_mod.emulate = emulate
-        new_mod.add_weight_tag()
         return new_mod

```
test/test_base.py

Lines changed: 0 additions & 5 deletions

```diff
@@ -175,11 +175,6 @@ def test_linear_bias(
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"

-    def test_linear_float8_weight_tag(self) -> None:
-        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref))
-        assert m_fp8.weight._is_fp8_weight
-
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
```
