This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 791b2bd

generatedunixname89002005367269 authored and facebook-github-bot committed

Daily arc lint --take BLACK

Reviewed By: martintrojer
Differential Revision: D50790931
fbshipit-source-id: 45afa339d95f2fef1c63e71572b5d93de8c5b582

1 parent 429a313 · commit 791b2bd
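This commit is mechanical: Meta's arc lint wrapper runs the Black code formatter over the repository, and every hunk below is the same transformation: a call or signature that overflows Black's default 88-character line length is split so its arguments sit on their own lines. Outside that tooling, a roughly equivalent standalone invocation would be (a sketch, assuming Black's default settings):

    pip install black
    black float8_experimental/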

File tree: 3 files changed, +18 -6 lines changed

float8_experimental/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Lets define a few top level things here
-from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_tensor import Float8Tensor
 
 __all__ = ["Float8Tensor", "Float8Linear"]
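The only change in this file is import ordering. Both public names are still re-exported at the package root, so imports like the following should keep working (a usage sketch based on the __all__ shown above):

    from float8_experimental import Float8Linear, Float8Tensor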

float8_experimental/float8_linear.py

Lines changed: 9 additions & 3 deletions
@@ -174,7 +174,9 @@ class DelayedScalingRecipe:
 
 class Float8LinearMixin(object):
     def __init__(self, *args, **kwargs):
-        delayed_scaling_recipe = kwargs.pop("delayed_scaling_recipe", DelayedScalingRecipe())
+        delayed_scaling_recipe = kwargs.pop(
+            "delayed_scaling_recipe", DelayedScalingRecipe()
+        )
         super().__init__(*args, **kwargs)
 
         # TODO(future): have a unique recipe per buffer instead of one per
@@ -268,7 +270,9 @@ def cast_y_to_float8_in_bw(self, y):
 
     def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
         scale_fn_name = self.recipe.scale_fn_name
-        y = float8_linear.apply(x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate)
+        y = float8_linear.apply(
+            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
+        )
         return y
 
     def float8_pre_forward(self, x):
@@ -407,7 +411,9 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
         #
         _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
         _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
-        _update_history_with_new_amax(child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY)
+        _update_history_with_new_amax(
+            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
+        )
 
         #
         # 3. calculate the scales
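The third hunk wraps a call inside sync_float8_amax_and_scale_history, whose signature appears in the hunk header. For orientation, here is a minimal sketch of where such a helper sits in a delayed-scaling training loop; the model, data, and optimizer are assumptions for illustration, not part of this diff:

    import torch
    from float8_experimental.float8_linear import sync_float8_amax_and_scale_history

    # assume `model` already contains float8 linear modules and `opt` is any optimizer
    for batch, target in data:
        loss = torch.nn.functional.mse_loss(model(batch), target)
        loss.backward()
        # fold the freshest amax readings into each module's history
        # before step 3 ("calculate the scales") consumes them
        sync_float8_amax_and_scale_history(model)
        opt.step()
        opt.zero_grad()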

float8_experimental/float8_tensor.py

Lines changed: 8 additions & 2 deletions
@@ -75,8 +75,13 @@ def backward(ctx, g):
         return g, None, None, None
 
 
-def to_float8(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype, amax_buffer:torch.Tensor =None) -> "Float8Tensor":
-    """ Converts a higher precision tensor to float8 in a differentiable way.
+def to_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+    amax_buffer: torch.Tensor = None,
+) -> "Float8Tensor":
+    """Converts a higher precision tensor to float8 in a differentiable way.
 
     Args:
         tensor: the tensor to convert
@@ -89,6 +94,7 @@ def to_float8(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dty
     """
     return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
 
+
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion from fp8
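After reformatting, the signature reads to_float8(tensor, scale, float8_dtype, amax_buffer=None) -> Float8Tensor. A minimal usage sketch; the scale heuristic is an assumption (a common per-tensor fp8 recipe), and torch.float8_e4m3fn requires a recent PyTorch build:

    import torch
    from float8_experimental.float8_tensor import to_float8

    x = torch.randn(16, 16)
    # derive a per-tensor scale from the absolute max so values fill the fp8 range
    scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
    x_fp8 = to_float8(x, scale, torch.float8_e4m3fn)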
