float8 profiling: save torch_logs, and attach them to profiling trace #645

Merged · 1 commit · Aug 12, 2024
141 changes: 103 additions & 38 deletions benchmarks/float8/profile_linear_float8.py
@@ -6,15 +6,19 @@

import copy
import io
import os
import random
from contextlib import nullcontext, redirect_stdout
from dataclasses import dataclass, field
from pathlib import Path
import pathlib
from typing import Callable, Optional

import fire
import pandas as pd

# disable inductor FX cache, so we can always study the inductor output logs
os.environ['TORCHINDUCTOR_FORCE_DISABLE_CACHES'] = '1'

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -30,6 +34,7 @@
parse_bw_and_kernel_name,
profiler_output_to_gpu_time_for_key,
profiler_output_to_filtered_time_by_kernel_name,
update_triton_kernels_in_prof_chome_trace_with_torch_logs,
)

# don't truncate long kernel names
@@ -38,6 +43,7 @@
pd.set_option("display.float_format", "{:.3f}".format)



class LNLinear(torch.nn.Module):
def __init__(self, fc_dim1, fc_dim2):
super().__init__()
@@ -151,7 +157,9 @@ def forward(self, h):

@dataclass
class ProfileConfig:
file_path: Optional[str] = None
trace_file_path: Optional[str] = None
logs_file_path: Optional[str] = None
trace_modified_file_path: Optional[str] = None
name: Optional[str] = None
cuda: bool = True
iters: int = 0
@@ -162,13 +170,33 @@ class ProfileConfig:


def profile_function(
config: ProfileConfig, func: Callable, *args, **kwargs
config: ProfileConfig,
func: Callable,
add_inductor_metadata_to_trace: bool,
*args,
**kwargs,
) -> torch.profiler.profile:
"""Profile a torch function and save the result to a file"""
seed = 123
random.seed(seed)
torch.manual_seed(seed)

if add_inductor_metadata_to_trace:
# ensure we aren't interfering with other TORCH_LOGS settings
if os.environ.get('TORCH_LOGS', '') != '':
raise AssertionError('using TORCH_LOGS together with add_inductor_metadata_to_trace is not supported yet')

# save torch.compile logs to a file specific to this benchmark run
# TODO(future): can we hack torch.compile to print to file only and not stdout?
# or maybe just use tlparse?
torch._logging.set_logs(output_code=True)
# by default torch.compile appends to log_file_name, so we delete it
# if it exists
if os.path.isfile(config.logs_file_path):
pathlib.Path(config.logs_file_path).unlink()
torch._logging._init_logs(log_file_name=config.logs_file_path)


activities = [ProfilerActivity.CPU]
if config.cuda:
activities.append(ProfilerActivity.CUDA)
@@ -182,6 +210,10 @@ def profile_function(
nullcontext() if config.name is None else record_function(config.name)
)
profile_memory = config.memory_profile_path is not None

# warm up: run once before profiling so compilation happens outside the profiled region
func(*args, **kwargs)

with profile(
activities=activities,
profile_memory=profile_memory,
@@ -195,20 +227,35 @@ def profile_function(
if config.sync:
torch.cuda.synchronize()

if config.file_path is not None:
prof.export_chrome_trace(config.file_path)
if config.trace_file_path is not None:
prof.export_chrome_trace(config.trace_file_path)

if add_inductor_metadata_to_trace:
# modify the trace to have the triton kernel metadata and code
# visible inline
update_triton_kernels_in_prof_chome_trace_with_torch_logs(
config.trace_file_path,
config.logs_file_path,
config.trace_modified_file_path,
)

# undo custom log settings
torch._logging.set_logs(output_code=False)
torch._logging._init_logs(log_file_name=None)

return prof
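
For context, here is a rough, hypothetical sketch of what a trace-merging helper like update_triton_kernels_in_prof_chome_trace_with_torch_logs can do (the real implementation lives in the benchmark utils and may differ): read the saved torch logs, collect the generated source for each triton kernel, and attach that source to the matching kernel events in the chrome trace JSON so it is visible inline in the trace viewer.

import json
import re

def attach_inductor_code_to_trace(trace_file, logs_file, out_file):
    # Hypothetical sketch only; the real helper is imported at the top of
    # this file from the benchmark utils.
    with open(logs_file) as f:
        log_lines = f.readlines()

    # Collect generated kernel source, keyed by kernel name, assuming the
    # output_code logs define kernels as `def triton_...(`.
    kernel_code, current_name = {}, None
    for line in log_lines:
        match = re.match(r"\s*def (triton_\w+)\(", line)
        if match is not None:
            current_name = match.group(1)
            kernel_code[current_name] = []
        if current_name is not None:
            kernel_code[current_name].append(line)

    with open(trace_file) as f:
        trace = json.load(f)

    # Attach the source to matching kernel events so it shows up in the
    # event details pane of the trace viewer.
    for event in trace.get("traceEvents", []):
        code = kernel_code.get(event.get("name", ""))
        if code is not None:
            event.setdefault("args", {})["inductor_code"] = "".join(code)

    with open(out_file, "w") as f:
        json.dump(trace, f)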


def main(
profile_path_prefix: Path,
profile_path_prefix: pathlib.Path,
compile: bool = True,
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
enable_sync_amax_history: bool = True,
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")
@@ -220,6 +267,8 @@ def main(
cast_config_input=CastConfig(scaling_type=scaling_type_input),
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
enable_amax_init=False,
enable_pre_and_post_forward=False,
)
scaling_repr = "_".join(
[
@@ -290,7 +339,7 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
if linear_requires_sync(config):
if linear_requires_sync(config) and enable_sync_amax_history:
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)
@@ -311,16 +360,14 @@ def float8_forw_backward_wrapper(x):

# if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
# to populate triton kernel bandwidth further down in the script
f = io.StringIO()
if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "":
context = nullcontext()
f = None
else:
f = io.StringIO()
context = redirect_stdout(f)
try:
with redirect_stdout(f):
# warm up
for _ in range(1):
if dtype_filter != "float8":
ref_forw_backward(input_tensor)
if dtype_filter != "bfloat16":
float8_forw_backward_wrapper(input_tensor)

with context:
profile_iters = 5
ref_times, float8_times = None, None
data = []
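
As an aside, parse_bw_and_kernel_name is imported from the benchmark utils at the top of this file; a minimal hypothetical version, assuming TORCHINDUCTOR_PROFILE prints a GB/s figure followed by the triton kernel name on each measurement line, might look like this (the real parser may differ):

import re

def parse_bw_and_kernel_name(line):
    # Hypothetical sketch: extract "<float> GB/s ... triton_<name>" from an
    # inductor profile line; return (None, None) when the line doesn't match.
    match = re.search(r"([\d.]+)\s*GB/s.*?\b(triton_\w+)", line)
    if match is None:
        return None, None
    return float(match.group(1)), match.group(2)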
@@ -330,13 +377,19 @@ def float8_forw_backward_wrapper(x):
if dtype_filter != "float8":
# Profile Reference Model
print("profiling ref")
ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
ref_path = profile_path_prefix + ref_suffix
ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json"
ref_logs_suffix = f"_{model_type}_ref_compile_{compile}.txt"
trace_ref_path = profile_path_prefix + ref_trace_suffix
log_ref_path = profile_path_prefix + ref_logs_suffix
trace_ref_modified_path = trace_ref_path.replace(".json", "_modified.json")
profile_config = ProfileConfig(
ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
trace_ref_path,
log_ref_path,
trace_ref_modified_path,
ref_trace_suffix,
iters=profile_iters,
warmup_iters=2,
sync=True,
)
p = profile_function(profile_config, ref_forw_backward, input_tensor)
print(f"saved {ref_path}")
p = profile_function(profile_config, ref_forw_backward, add_inductor_metadata_to_trace, input_tensor)
print(f"saved profiling trace to {trace_ref_path}")
if add_inductor_metadata_to_trace:
print(f"saved torch logs to {log_ref_path}")
print(f"saved modified trace to {trace_ref_modified_path}")
ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
for k, v in ref_times.items():
@@ -355,21 +408,31 @@ def float8_forw_backward_wrapper(x):
if dtype_filter != "bfloat16":
# Profile Float8 Model
print("profiling float8")
float8_suffix = (
float8_trace_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
)
float8_path = profile_path_prefix + float8_suffix
float8_log_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt"
)
trace_float8_path = profile_path_prefix + float8_trace_suffix
log_float8_path = profile_path_prefix + float8_log_suffix
trace_float8_modified_path = trace_float8_path.replace(".json", "_modified.json")
profile_config = ProfileConfig(
float8_path,
float8_suffix,
trace_float8_path,
log_float8_path,
trace_float8_modified_path,
float8_trace_suffix,
iters=profile_iters,
warmup_iters=2,
sync=True,
)
p = profile_function(
profile_config, float8_forw_backward_wrapper, input_tensor
profile_config, float8_forw_backward_wrapper, add_inductor_metadata_to_trace, input_tensor
)
print(f"saved {float8_path}")
print(f"saved profiling trace to {trace_float8_path}")
if add_inductor_metadata_to_trace:
print(f"saved torch logs to {log_float8_path}")
print(f"saved modified trace to {trace_float8_modified_path}")
float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
for k, v in float8_times.items():
@@ -393,17 +456,19 @@ def float8_forw_backward_wrapper(x):
print(f"Sync time ms: {sync_time_ms}")

finally:
# print the redirected stdout back to regular stdout
print(f.getvalue())

# populate the triton kernel bandwidth
for line in f.getvalue().split("\n"):
maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
if maybe_kernel_name is not None:
# O(N) search, but it's ok since lists are small
for datum in data:
if datum[1] == maybe_kernel_name:
datum[-1] = maybe_bw
if f is not None:
# print the redirected stdout back to regular stdout
print(f.getvalue())

if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
# populate the triton kernel bandwidth
for line in f.getvalue().split("\n"):
maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
if maybe_kernel_name is not None:
# O(N) search, but it's ok since lists are small
for datum in data:
if datum[1] == maybe_kernel_name:
datum[-1] = maybe_bw

df = pd.DataFrame(
data,
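
For reference, a typical run of the updated benchmark might look like the following (hypothetical paths; this assumes the script's usual fire.Fire(main) entrypoint, which maps main's parameters to CLI flags, and that the repo root is on sys.path for the direct call):

# CLI form:
#   python benchmarks/float8/profile_linear_float8.py \
#       --profile_path_prefix /tmp/f8_profile \
#       --model_type linear \
#       --add_inductor_metadata_to_trace True
#
# Equivalent direct call:
from benchmarks.float8.profile_linear_float8 import main

main(
    profile_path_prefix="/tmp/f8_profile",  # suffixes like "_linear_ref_compile_True.json" are appended
    model_type="linear",
    dtype_filter="both",
    add_inductor_metadata_to_trace=True,
)

The resulting *_modified.json traces are ordinary chrome traces and can be opened in chrome://tracing or Perfetto.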