Ruff Lint fixes (torchao/float8) (#1239)
jainapurva authored Nov 11, 2024
1 parent addcb24 commit 299aacd
Showing 11 changed files with 161 additions and 128 deletions.
7 changes: 3 additions & 4 deletions ruff.toml
@@ -3,17 +3,16 @@
# To add a new path: Simply add it to the 'include' list.
# Example: To lint all files in every subfolder of 'test', add "test/**/*"
include = [
"torchao/float8/inference.py",
"torchao/float8/float8_utils.py",
"torchao/float8/**/*.py",
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"torchao/quantization/weight_tensor_linear_activation_quantization.py",
"torchao/dtypes/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"test/prototype/low_bit_optim/**.py",

]

lint.ignore = ["E731"]
6 changes: 2 additions & 4 deletions torchao/float8/__init__.py
@@ -11,7 +11,7 @@
Float8LinearConfig,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear, WeightWithDelayedFloat8CastTensor
from torchao.float8.float8_linear import WeightWithDelayedFloat8CastTensor
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
@@ -23,12 +23,10 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.inference import Float8MMConfig
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

from torchao.float8.inference import Float8MMConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


if TORCH_VERSION_AT_LEAST_2_5:
# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals
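For context on the guarded block above: PyTorch 2.5 added `add_safe_globals`, which registers tensor subclasses such as `Float8Tensor` so they can be deserialized with `torch.load(..., weights_only=True)`. A minimal sketch of how that registration is typically used (the checkpoint filename is hypothetical):

    import torch
    from torch.serialization import add_safe_globals

    from torchao.float8.float8_tensor import Float8Tensor

    # Register the tensor subclass as safe so weights_only loading will accept it.
    add_safe_globals([Float8Tensor])

    # Hypothetical checkpoint that contains Float8Tensor weights.
    state_dict = torch.load("model_fp8.pt", weights_only=True)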
37 changes: 23 additions & 14 deletions torchao/float8/config.py
@@ -13,6 +13,7 @@

logger: logging.Logger = logging.getLogger()


class ScalingType(enum.Enum):
DELAYED = "delayed"
DYNAMIC = "dynamic"
@@ -71,8 +72,10 @@ def __post_init__(self):
self.static_scale is not None
), "static_scale must be specified for static scaling"
if self.scaling_granularity is ScalingGranularity.AXISWISE:
assert self.scaling_type is ScalingType.DYNAMIC, \
"only dynamic scaling type is supported for axiswise scaling granularity"
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"


@dataclass(frozen=True)
class DelayedScalingConfig:
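The assert that Ruff reflowed above encodes a config invariant: axiswise scaling granularity is only valid together with dynamic scaling. A hedged sketch of a cast config that satisfies it, assuming `CastConfig` exposes `scaling_type` and `scaling_granularity` as keyword fields, as the checks above suggest:

    from torchao.float8.config import CastConfig, ScalingGranularity, ScalingType

    # Dynamic scaling is the only scaling type allowed with axiswise granularity;
    # delayed or static scaling here would trip the assert in __post_init__.
    axiswise_cast_config = CastConfig(
        scaling_type=ScalingType.DYNAMIC,
        scaling_granularity=ScalingGranularity.AXISWISE,
    )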
@@ -226,7 +229,7 @@ class Float8LinearConfig:

# If True, we only use fp8-all-gather to reduce the communication cost.
# The gemm computation is still done in the original precision.
# `cast_config_weight` is used to decide how to cast the weight to fp8,
# `cast_config_weight` is used to decide how to cast the weight to fp8,
# other casting configs will be ignored.
use_fp8_all_gather_only: bool = False

@@ -238,16 +241,23 @@ def __post_init__(self):
# to work.
# Source of hack: https://stackoverflow.com/a/65959419/
if self.cast_config_input_for_grad_weight is None:
object.__setattr__(self, "cast_config_input_for_grad_weight", self.cast_config_input)
object.__setattr__(
self, "cast_config_input_for_grad_weight", self.cast_config_input
)
if self.cast_config_weight_for_grad_input is None:
object.__setattr__(self, "cast_config_weight_for_grad_input", self.cast_config_weight)
object.__setattr__(
self, "cast_config_weight_for_grad_input", self.cast_config_weight
)
if self.cast_config_grad_output_for_grad_weight is None:
object.__setattr__(self, "cast_config_grad_output_for_grad_weight", self.cast_config_grad_output)
object.__setattr__(
self,
"cast_config_grad_output_for_grad_weight",
self.cast_config_grad_output,
)

# float8 all-gather only supports tensorwise, in the future may support blockwise
if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
assert not self.enable_fsdp_float8_all_gather, \
f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"

# save some characters in the compatibility checks below
cc_i = self.cast_config_input
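The `object.__setattr__` calls that Ruff reflowed above are the standard workaround for assigning derived defaults inside `__post_init__` of a frozen dataclass (the StackOverflow link in the code comment describes it). A minimal, self-contained sketch of the same pattern, using a made-up config class:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass(frozen=True)
    class ExampleConfig:  # hypothetical class, only to illustrate the pattern
        cast_config_input: str = "dynamic"
        cast_config_input_for_grad_weight: Optional[str] = None

        def __post_init__(self):
            # A frozen dataclass forbids normal attribute assignment, so fall back
            # to object.__setattr__ to fill in the derived default.
            if self.cast_config_input_for_grad_weight is None:
                object.__setattr__(
                    self, "cast_config_input_for_grad_weight", self.cast_config_input
                )


    cfg = ExampleConfig()
    assert cfg.cast_config_input_for_grad_weight == "dynamic"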
@@ -266,12 +276,13 @@ def __post_init__(self):
):
is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
assert is_disabled_1 == is_disabled_2, \
f"incompatible operand precision for {gemm_name}"

assert (
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

# See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
if (
self.enable_fsdp_float8_all_gather
@@ -280,7 +291,6 @@ def __post_init__(self):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)



# Pre-made recipes for common configurations
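The warning above pairs two `Float8LinearConfig` knobs: enabling fp8 all-gather under FSDP and forcing recomputation of the fp8-cast weight in the backward pass. A hedged sketch of a config that follows that recommendation, assuming both flags are plain boolean fields as the diff suggests:

    from torchao.float8.config import Float8LinearConfig

    # Follows the recommendation in the warning: when fp8 all-gather is enabled
    # under FSDP, also force recomputation of the fp8-cast weight in backward.
    fsdp_friendly_config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=True,
        force_recompute_fp8_weight_in_bwd=True,
    )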
@@ -328,7 +338,6 @@ def recipe_name_to_linear_config(
)

elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:

# lw's recipe for a modification on all-axiswise:
#
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
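The hunk above sits inside `recipe_name_to_linear_config`, which turns a `Float8LinearRecipeName` into a full `Float8LinearConfig`. A hedged end-to-end sketch, assuming both names are importable from `torchao.float8.config`, that `convert_to_float8_training` accepts a `config` keyword, and using a placeholder model:

    import torch.nn as nn

    from torchao.float8 import convert_to_float8_training
    from torchao.float8.config import (
        Float8LinearRecipeName,
        recipe_name_to_linear_config,
    )

    # Placeholder model; any stack of nn.Linear modules works for illustration.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

    # Build the recipe referenced in this hunk and swap eligible Linear layers to float8.
    config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
    convert_to_float8_training(model, config=config)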
1 change: 0 additions & 1 deletion torchao/float8/distributed_utils.py
@@ -6,7 +6,6 @@
from typing import Any

import torch

from fairscale.nn.model_parallel.initialize import get_model_parallel_group

# from float8_tensor import Float8Tensor
(Diffs for the remaining 7 changed files are not shown here.)
