
float8 training axiswise scaling support with per-gemm-argument configuration #940
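For context, the sketch below shows how the per-gemm-argument granularity configuration added in this PR might be wired up. It uses only names that appear in the diffs further down (CastConfig, Float8LinearConfig, ScalingType, ScalingGranularity); the constructor defaults and the choice to share one CastConfig across all three gemm arguments are assumptions for illustration, not a definitive API reference.

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    ScalingType,
)

# Dynamic scaling with axiswise granularity; the test changes in this commit
# only exercise axiswise granularity together with dynamic scaling and bfloat16.
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
)

# One CastConfig per gemm argument (input, weight, grad_output), mirroring the
# benchmark and test changes below; here all three share the same settings.
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)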


Merged
merged 51 commits into from
Oct 7, 2024
Changes from 1 commit
Commits
51 commits
183dbec
Update
vkuzo Sep 23, 2024
241f815
Update
vkuzo Sep 23, 2024
d0b1002
Update
vkuzo Sep 23, 2024
f15c2a0
Update
vkuzo Sep 23, 2024
c816ce9
Update
vkuzo Sep 23, 2024
9150b4f
Update
vkuzo Sep 23, 2024
40279fb
Update
vkuzo Sep 23, 2024
459e92c
Update
vkuzo Sep 23, 2024
732b231
Update
vkuzo Sep 23, 2024
e7c15d1
Update
vkuzo Sep 24, 2024
1d01df3
Update
vkuzo Sep 25, 2024
62acdaf
Update
vkuzo Sep 25, 2024
381e16e
Update
vkuzo Sep 27, 2024
afdf660
Update
vkuzo Sep 27, 2024
d53f2ce
Update
vkuzo Sep 27, 2024
0737eb8
Update
vkuzo Sep 27, 2024
2791eb3
Update
vkuzo Sep 27, 2024
6cfd1cd
Update
vkuzo Sep 27, 2024
94907e5
Update
vkuzo Sep 27, 2024
423760a
Update
vkuzo Sep 27, 2024
552db23
Update
vkuzo Sep 27, 2024
953bc2f
Update
vkuzo Sep 27, 2024
c5d19e0
Update
vkuzo Sep 27, 2024
24f0f3b
Update
vkuzo Sep 27, 2024
7e0fe97
Update
vkuzo Sep 27, 2024
10f2e0f
Update
vkuzo Sep 27, 2024
4437054
Update
vkuzo Sep 30, 2024
31a017b
Update
vkuzo Sep 30, 2024
743b4c1
Update
vkuzo Sep 30, 2024
0c473c4
Update
vkuzo Oct 2, 2024
c1be278
Update
vkuzo Oct 2, 2024
179e3b3
Update
vkuzo Oct 2, 2024
fc8d4ef
Update
vkuzo Oct 2, 2024
76322de
Update
vkuzo Oct 2, 2024
4eec2bd
Update
vkuzo Oct 2, 2024
ac6f768
Update
vkuzo Oct 2, 2024
02b9fca
Update
vkuzo Oct 2, 2024
5756aa5
Update
vkuzo Oct 2, 2024
1f01df9
Update
vkuzo Oct 4, 2024
b9bcc30
Update
vkuzo Oct 4, 2024
e595f30
Update
vkuzo Oct 4, 2024
4bb59a6
Update
vkuzo Oct 4, 2024
c1c218f
Update
vkuzo Oct 4, 2024
024fe94
Update
vkuzo Oct 4, 2024
4027694
Update
vkuzo Oct 4, 2024
076de91
Update
vkuzo Oct 4, 2024
ca127f0
Update
vkuzo Oct 4, 2024
712fd5d
Update
vkuzo Oct 5, 2024
d70326c
Update
vkuzo Oct 7, 2024
f2d104a
Update
vkuzo Oct 7, 2024
b536435
Update
vkuzo Oct 7, 2024
Update
[ghstack-poisoned]
vkuzo committed Sep 23, 2024
commit 241f815f73b85f27a59fdf9f8e9d0a80cedb5f2c
32 changes: 27 additions & 5 deletions benchmarks/float8/bench_linear_float8.py
@@ -14,7 +14,12 @@

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
@@ -107,35 +112,49 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
@@ -167,7 +186,7 @@ def main(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = linear_float8.scaling_repr()
scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -310,6 +329,7 @@ def invoke_main() -> None:
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
parser.add_argument("--scaling_granularity", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
@@ -327,6 +347,8 @@ def invoke_main() -> None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
if args.scaling_granularity is not None:
kwargs["scaling_granularity"] = args.scaling_granularity
main(
output_path,
not args.disable_compile,
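With the argument parsing added above, the benchmark's scaling granularity can presumably be selected from the command line alongside the existing scaling-type flags. An illustrative invocation (other flags and any required arguments omitted):

python benchmarks/float8/bench_linear_float8.py --scaling_type_input dynamic --scaling_type_weight dynamic --scaling_type_grad_output dynamic --scaling_granularity axiswise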
27 changes: 23 additions & 4 deletions benchmarks/float8/profile_linear_float8.py
@@ -22,7 +22,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
@@ -252,6 +257,7 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
@@ -263,28 +269,41 @@
scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
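The profiling script follows the same pattern as the benchmark: the scaling_granularity string is parsed into a ScalingGranularity and threaded into each CastConfig. Below is a minimal sketch of converting a toy model with such a config via convert_to_float8_training, which this file imports; passing the config as a config keyword argument is an assumption not confirmed by this diff.

import torch
import torch.nn as nn

from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.float8_linear_utils import convert_to_float8_training

# Axiswise dynamic scaling for all three gemm arguments, as in the sketch near
# the top of this page.
cc = CastConfig(scaling_type=ScalingType.DYNAMIC, scaling_granularity=ScalingGranularity.AXISWISE)
config = Float8LinearConfig(cast_config_input=cc, cast_config_weight=cc, cast_config_grad_output=cc)

# Toy model; bfloat16 matches the dtype the new axiswise test path exercises.
m = nn.Sequential(nn.Linear(16, 32, bias=False)).to(device="cuda", dtype=torch.bfloat16)
convert_to_float8_training(m, config=config)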
32 changes: 29 additions & 3 deletions test/float8/test_base.py
@@ -327,6 +327,10 @@ def _test_linear_impl(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_granularity",
[ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
)
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -337,6 +341,7 @@ def test_linear(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
scaling_granularity: ScalingGranularity,
linear_dtype: torch.dtype,
linear_bias: bool,
):
@@ -349,30 +354,51 @@
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
if scaling_granularity is ScalingGranularity.AXISWISE:
if (
scaling_type_input != ScalingType.DYNAMIC or
scaling_type_weight != ScalingType.DYNAMIC or
scaling_type_grad_output != ScalingType.DYNAMIC or
linear_dtype != torch.bfloat16
):
pytest.skip()

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_input = CastConfig(scaling_type=scaling_type_input)
cast_config_input = CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
cast_config_weight = CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output = CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
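Note the skip logic added to test_linear above: in this commit, axiswise granularity is only exercised when all three scaling types are dynamic and the linear dtype is bfloat16; every other parameter combination is skipped rather than tested.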