This repository was archived by the owner on Aug 7, 2024. It is now read-only.

rename all variables to use input/weight/grad_output notation #335

Closed
wants to merge 2 commits into from
4 changes: 2 additions & 2 deletions float8_experimental/float8_dynamic_utils.py
@@ -42,7 +42,7 @@ def backward(ctx, gradY):
gradY_scale,
e5m2_dtype,
linear_mm_config=ctx.linear_mm_config,
gemm_input_role=GemmInputRole.DL_DY,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
)
return fp8_tensor, None

@@ -51,7 +51,7 @@ def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.X,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
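For orientation, a minimal usage sketch of the renamed role argument; the import paths follow the file names in this diff, while the tensors, shapes, and CUDA device below are assumptions and not part of the PR.

import torch
from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import GemmInputRole, LinearMMConfig

# Placeholder high-precision tensors standing in for an activation and a weight;
# a CUDA device is assumed for float8 support.
activation = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()  # per-gemm defaults defined in float8_tensor.py

# The role argument now names what the tensor is instead of a single-letter symbol.
input_fp8 = cast_to_float8_e4m3_dynamic(
    activation, linear_mm_config, gemm_input_role=GemmInputRole.INPUT
)
weight_fp8 = cast_to_float8_e4m3_dynamic(
    weight, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT
)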
76 changes: 39 additions & 37 deletions float8_experimental/float8_linear.py
@@ -125,7 +125,7 @@ def backward(ctx, go):
fp8_scale_grad_output,
e5m2_dtype,
linear_mm_config=ctx.linear_mm_config,
gemm_input_role=GemmInputRole.DL_DY,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads
@@ -273,21 +273,21 @@ def convert_amax_buffer_to_float32(self):
if self._buffers[key] is not None:
self._buffers[key] = self._buffers[key].to(torch.float32)

def cast_x_to_float8(
self, x: torch.Tensor, is_amax_initialized: bool
def cast_input_to_float8(
self, input: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
x = x.to(autocast_dtype)
input = input.to(autocast_dtype)

if self.scaling_type_input is TensorScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
x,
input,
self.fp8_amax_input,
self.fp8_amax_history_input,
self.fp8_scale_input,
@@ -296,29 +296,29 @@ def cast_x_to_float8(
is_amax_initialized,
reduce_amax=True,
)
x_fp8 = Float8Tensor.to_float8(
x,
input_fp8 = Float8Tensor.to_float8(
input,
self.fp8_scale_input,
e4m3_dtype,
self.fp8_amax_input,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
)
else:
assert self.scaling_type_input is TensorScalingType.DYNAMIC
x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
return x_fp8
input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
return input_fp8

def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
def cast_weight_to_float8(
self, weight: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
if self.scaling_type_weight is TensorScalingType.DELAYED:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
weight_fp8 = self.weight
else:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
@@ -328,29 +328,31 @@ def cast_w_to_float8(
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
weight_fp8 = Float8Tensor.to_float8(
weight,
self.fp8_scale_weight,
e4m3_dtype,
self.fp8_amax_weight,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
else:
assert self.scaling_type_weight is TensorScalingType.DYNAMIC
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
weight_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(
self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
weight_fp8 = cast_to_float8_e4m3_dynamic(
self.weight,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return w_fp8
return weight_fp8

def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is TensorScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
output = NoopFwToFloat8E5M2Bw.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
self.fp8_scale_grad_output,
@@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
)
else:
assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
return y
output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
return output

def float8_pre_forward(self, x):
def float8_pre_forward(self, input):
if not self.enable_pre_and_post_forward:
return
if (
@@ -374,7 +376,7 @@ def float8_pre_forward(self, x):
raise AssertionError(
"amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
)
self.last_seen_input_dtype = x.dtype
self.last_seen_input_dtype = input.dtype

def float8_post_forward(self):
if not self.enable_pre_and_post_forward:
@@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.has_any_delayed_scaling:
self.float8_pre_forward(input)

x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())
output = torch.matmul(input_fp8, weight_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y)
# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)

if self.bias is not None:
y = y + self.bias.to(y.dtype)
output = output + self.bias.to(output.dtype)

if self.has_any_delayed_scaling:
self.float8_post_forward()
return y
return output

def scaling_repr(self):
# add scaling settings without using too many characters
# example: "x:del,w:del,dldy:dyn"
return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}"
# example: "i:del,w:del,go:dyn"
return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"

def extra_repr(self):
s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
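With the shortened keys, the module repr reads roughly as below; the feature sizes are placeholders and the surrounding text is just nn.Linear's default extra_repr, so treat this as an illustration rather than captured output.

# e.g. delayed scaling for input and weight, dynamic scaling for grad_output:
# Float8Linear(in_features=1024, out_features=4096, bias=True, scaling="i:del,w:del,go:dyn")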
55 changes: 28 additions & 27 deletions float8_experimental/float8_tensor.py
@@ -27,21 +27,21 @@
#
# There are three gemms in a forward + backward of a Linear layer:
#
# 1. x @ w_t = y (forward pass)
# 2. dL_dY @ w = dL_dX (backward pass)
# 3. x_t @ dL_dY = dL_dW (backward pass)
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# In the formulas above, there are:
# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
# - Note that dL_dY_t is implied because of memory format requirements
# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t).
# - Note that grad_output_t is implied because of memory format requirements
# of float8 gemms
# B. three output tensors (y, dL_dX, dL_dW)
# B. three output tensors (output, grad_input, grad_weight)
#
# We want each input tensor, gemm, and output tensor to be configurable.
# The state of this configuration today is:
#
# i. pairs of input tensors (non-t and t variants) have their scaling
# configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
# configurable via the scaling_type_* arguments to Float8Linear
# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
# to configure all three gemms, also not user facing
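To make the new names concrete, here is a plain-PyTorch sketch of the three gemms listed above; the shapes are placeholders and no float8 casting is involved, it only shows which tensor plays which role.

import torch

M, K, N = 16, 32, 64
input = torch.randn(M, K)        # activations entering the linear
weight = torch.randn(N, K)       # nn.Linear stores weight as (out_features, in_features)
grad_output = torch.randn(M, N)  # gradient flowing back into the output

output = input @ weight.t()              # 1. forward
grad_input = grad_output @ weight        # 2. backward w.r.t. the input
grad_weight_t = input.t() @ grad_output  # 3. backward w.r.t. the weight, in (K, N) layout
                                         #    (nn.Linear's weight grad is the transpose of this)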
@@ -60,11 +60,12 @@

# The object below is not user facing and exists for convenience,
# to allow Float8Tensor to use
# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
# the right config based on which gemm from gemms with outputs
# `output`, `grad_input`, `grad_weight` is
# being called.
LinearMMConfig = namedtuple(
"LinearMMConfig",
["y", "dL_dX", "dL_dW"],
["output", "grad_input", "grad_weight"],
defaults=[
ScaledMMConfig(False, True, False, False),
ScaledMMConfig(False, False, False, False),
@@ -81,9 +82,9 @@ class GemmInputRole(enum.Enum):
gemm is performed.
"""

X = "x"
W = "w"
DL_DY = "dL_dY"
INPUT = "input"
WEIGHT = "weight"
GRAD_OUTPUT = "grad_output"


# choose which scaled_mm_config to use based on gemm inputs
@@ -93,21 +94,21 @@ def choose_scaled_mm_config(
b_role: GemmInputRole,
b_linear_mm_config: LinearMMConfig,
):
if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT:
assert (
a_linear_mm_config.y == b_linear_mm_config.y
), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}"
return a_linear_mm_config.y
elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
a_linear_mm_config.output == b_linear_mm_config.output
), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}"
return a_linear_mm_config.output
elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT:
assert (
a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
return a_linear_mm_config.dL_dX
elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
a_linear_mm_config.grad_input == b_linear_mm_config.grad_input
), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}"
return a_linear_mm_config.grad_input
elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT:
assert (
a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
return a_linear_mm_config.dL_dW
a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight
), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}"
return a_linear_mm_config.grad_weight
else:
raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
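A hedged sketch of the mapping this dispatch implements, with the argument order (a_role, a_linear_mm_config, b_role, b_linear_mm_config) inferred from the visible signature:

cfg = LinearMMConfig()  # the three per-gemm ScaledMMConfig defaults shown above

# input @ weight selects the forward-output gemm config
assert choose_scaled_mm_config(GemmInputRole.INPUT, cfg, GemmInputRole.WEIGHT, cfg) is cfg.output
# grad_output @ weight selects the grad_input gemm config
assert choose_scaled_mm_config(GemmInputRole.GRAD_OUTPUT, cfg, GemmInputRole.WEIGHT, cfg) is cfg.grad_input
# grad_output @ input selects the grad_weight gemm config
assert choose_scaled_mm_config(GemmInputRole.GRAD_OUTPUT, cfg, GemmInputRole.INPUT, cfg) is cfg.grad_weight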

@@ -207,7 +208,7 @@ def forward(
float8_dtype=e4m3_dtype,
amax_buffer: Optional[torch.Tensor] = None,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
"""Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
Args
@@ -287,7 +288,7 @@ def __new__(
scale: torch.Tensor,
orig_dtype: torch.dtype,
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
assert (
scale.numel() == 1
@@ -348,7 +349,7 @@ def to_float8(
float8_dtype: torch.dtype,
amax_buffer: Optional[torch.Tensor] = None,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
"""Converts a higher precision tensor to float8 in a differentiable way.

6 changes: 3 additions & 3 deletions float8_experimental/float8_tensor_parallel.py
@@ -48,7 +48,7 @@ def _prepare_input_fn(
input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
@@ -101,7 +101,7 @@ def _prepare_input_fn(
input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
@@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
dt_inp = cast_to_float8_e4m3_dynamic(
dt_inp,
self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))
10 changes: 5 additions & 5 deletions float8_experimental/fsdp_utils.py
@@ -171,14 +171,14 @@ def fsdp_pre_all_gather(self, mesh):
self._precomputed_scale,
torch.float8_e4m3fn,
linear_mm_config=self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
else:
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor,
self._linear_mm_config,
reduce_amax=True,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
return (float8_tensor._data,), (float8_tensor._scale,)

@@ -201,7 +201,7 @@ def fsdp_post_all_gather(
scale,
param_dtype,
self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
), (data,)


@@ -364,7 +364,7 @@ def fsdp_pre_all_gather(self, mesh):
e4m3_dtype,
self._amax_buffer,
self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
return (float8_tensor._data,), (float8_tensor._scale,)

@@ -387,5 +387,5 @@ def fsdp_post_all_gather(
scale,
param_dtype,
self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
), (data,)
4 changes: 2 additions & 2 deletions float8_experimental/inference.py
@@ -132,7 +132,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
scale,
dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
self.weight = nn.Parameter(quantized_weight)
self.weight.requires_grad = False
@@ -205,7 +205,7 @@ def cast_to_float8_e4m3_inference(
scale,
e4m3_dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
)

