Skip to content

float8: remove unneeded kernel for scale generation #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions benchmarks/float8/bench_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import ScaledMMConfig
from utils import get_name_to_shapes_iter
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
Expand Down Expand Up @@ -96,6 +97,11 @@ def main(
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
*,
shape_gen_name: str = 'llama',
M: Optional[int] = None,
K: Optional[int] = None,
N: Optional[int] = None,
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
Expand All @@ -112,26 +118,19 @@ def main(
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
input_bias = False
if fast_accum_filter is not None:
use_fast_accum = [fast_accum_filter]
else:
use_fast_accum = [True, False]
if shape_name_filter is not None:
k = shape_name_filter
name_to_shapes_70b = {k: name_to_shapes_70b[k]}
name_to_shapes = ((k, v) for (k, v) in name_to_shapes if k == shape_name_filter)
experiment_list: List[Experiment] = []
dtype = torch.bfloat16
for idx, (fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
for idx, (fast_accum, (name, (M, K, N))) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes)))
):
if n_limit is not None and idx >= n_limit:
break
Expand All @@ -150,8 +149,6 @@ def main(
else:
linear_float8.forward_config = ScaledMMConfig(False, False, False)

bsz, seq_len = 4, 4096
M = bsz * seq_len
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

Expand Down Expand Up @@ -279,6 +276,10 @@ def invoke_main() -> None:
parser.add_argument("-o", "--output_path", type=str, required=False)
parser.add_argument("--disable_compile", action="store_true")
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--shape_gen_name", type=str, required=False)
parser.add_argument("--M", type=int, required=False)
parser.add_argument("--K", type=int, required=False)
parser.add_argument("--N", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--scaling_type_input", type=str, required=False)
Expand All @@ -287,6 +288,14 @@ def invoke_main() -> None:
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
if args.shape_gen_name is not None:
kwargs["shape_gen_name"] = args.shape_gen_name
if args.M is not None:
kwargs["M"] = args.M,
if args.K is not None:
kwargs["K"] = args.K,
if args.N is not None:
kwargs["N"] = args.N,
if args.scaling_type_input is not None:
kwargs["scaling_type_input"] = args.scaling_type_input
if args.scaling_type_weight is not None:
Expand Down
63 changes: 2 additions & 61 deletions benchmarks/float8/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import torch.nn as nn
import torch.utils.benchmark as benchmark

from utils import get_name_to_shapes_iter

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

Expand Down Expand Up @@ -48,67 +50,6 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
return time_sec, tops_sec, pct_top_peak


def get_name_to_shapes_iter(
shape_gen_name: str,
M: Optional[int],
K: Optional[int],
N: Optional[int],
):
if shape_gen_name == 'llama':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
bsz, seq_len = 4, 4096
M = bsz * seq_len
# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (M, 8192, 1280),
"attn.w0": (M, 1024, 8192),
"ffn.w13": (M, 8192, 7168),
"ffn.w2": (M, 3584, 8192),
}
return name_to_shapes_70b.items()

elif shape_gen_name == 'square':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
name_to_shapes = {}
min_power_of_2 = 5 # 32
max_power_of_2 = 16 # 65,536
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
val = 2 ** power_of_2
name_to_shapes[idx] = val, val, val
return name_to_shapes.items()

elif shape_gen_name == 'sweep':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
name_to_shapes = {}
min_p2 = 5 # 32
max_p2 = 16 # 65,536
counter = 0
for M_p2 in range(min_p2, max_p2 + 1):
M = 2 ** M_p2
for K_p2 in range(min_p2, max_p2 + 1):
K = 2 ** K_p2
for N_p2 in range(min_p2, max_p2 + 1):
N = 2 ** N_p2
name_to_shapes[counter] = M, K, N
counter += 1
return name_to_shapes.items()

elif shape_gen_name == 'custom':
assert M is not None and K is not None and N is not None, \
'M, K, N must be specified for custom shape_gen'
name_to_shapes = {
1: (M, K, N),
}
return name_to_shapes.items()

raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')


@torch.inference_mode()
def run(
n_limit: Optional[int] = None,
Expand Down
14 changes: 12 additions & 2 deletions benchmarks/float8/profile_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def main(
model_type: str = "linear",
dtype_filter: str = "both",
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

scaling_type_input = ScalingType(scaling_type_input)
Expand Down Expand Up @@ -250,8 +250,18 @@ def main(
input_tensor = torch.randn(
1, 8192, 4096, device=device, dtype=ref_dtype
).requires_grad_()
elif model_type == "norm_ffn_norm_small":
m_ref = NormFFNResidualNorm(
dim=4096,
hidden_dim=4096,
multiple_of=1024,
ffn_dim_multiplier=1.0,
)
input_tensor = torch.randn(
1, 2048, 4096, device=device, dtype=ref_dtype
).requires_grad_()
else:
M, K, N = 4 * 4096, 8192, 7168
M, K, N = 4096, 4096, 4096
m_ref = torch.nn.Sequential(
torch.nn.Linear(K, N, bias=False),
)
Expand Down
62 changes: 62 additions & 0 deletions benchmarks/float8/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import collections
import re
from typing import Optional


def profiler_output_to_time_by_kernel_name(prof):
Expand Down Expand Up @@ -81,3 +82,64 @@ def parse_bw_and_kernel_name(line):
return result.group(1), result.group(2)
else:
return None, None


def get_name_to_shapes_iter(
shape_gen_name: str,
M: Optional[int],
K: Optional[int],
N: Optional[int],
):
if shape_gen_name == 'llama':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
bsz, seq_len = 4, 4096
M = bsz * seq_len
# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (M, 8192, 1280),
"attn.w0": (M, 1024, 8192),
"ffn.w13": (M, 8192, 7168),
"ffn.w2": (M, 3584, 8192),
}
return name_to_shapes_70b.items()

elif shape_gen_name == 'square':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
name_to_shapes = {}
min_power_of_2 = 5 # 32
max_power_of_2 = 16 # 65,536
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
val = 2 ** power_of_2
name_to_shapes[idx] = val, val, val
return name_to_shapes.items()

elif shape_gen_name == 'sweep':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
name_to_shapes = {}
min_p2 = 5 # 32
max_p2 = 16 # 65,536
counter = 0
for M_p2 in range(min_p2, max_p2 + 1):
M = 2 ** M_p2
for K_p2 in range(min_p2, max_p2 + 1):
K = 2 ** K_p2
for N_p2 in range(min_p2, max_p2 + 1):
N = 2 ** N_p2
name_to_shapes[counter] = M, K, N
counter += 1
return name_to_shapes.items()

elif shape_gen_name == 'custom':
assert M is not None and K is not None and N is not None, \
'M, K, N must be specified for custom shape_gen'
name_to_shapes = {
1: (M, K, N),
}
return name_to_shapes.items()

raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')
2 changes: 1 addition & 1 deletion test/float8/test_everything.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ IS_ROCM=$(rocm-smi --version || true)

pytest test/float8/test_base.py
pytest test/float8/test_compile.py
pytest test/float8/test_inference_flows.py
# pytest test/float8/test_inference_flows.py
pytest test/float8/test_numerics_integration.py

# These tests do not work on ROCm yet
Expand Down
3 changes: 1 addition & 2 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def amax_to_scale(
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
Expand All @@ -53,7 +52,7 @@ def amax_to_scale(
# to care about this for float32/bfloat16.
if orig_dtype is torch.float16:
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
scale.copy_(res)
return res.float()
return scale


Expand Down
Loading