Add TK Wave kernels to conv benchmark (#35)
* Add an option to convbench to test TKW-based conv kernels.
* Only a limited subset of datatypes is supported for now (only `f16xf16xf32`).
* Requires the latest iree-turbine main.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Dec 19, 2024
1 parent c3bdf8e commit 19c832f
Showing 4 changed files with 146 additions and 7 deletions.
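
The commit message notes that only `f16xf16xf32` is supported on the TK/Wave path for now, so a caller may want to filter configs up front. A minimal sketch of such a guard (hypothetical helper, not part of this commit):

```python
# Hypothetical guard illustrating the dtype limitation noted in the commit
# message: the TK/Wave path currently handles only f16 x f16 accumulating to f32.
SUPPORTED_TK_DTYPES = {("f16", "f16", "f32")}


def is_tk_supported(input_dtype: str, filter_dtype: str, output_dtype: str) -> bool:
    """Return True if this dtype combination is expected to work with --tk."""
    return (input_dtype, filter_dtype, output_dtype) in SUPPORTED_TK_DTYPES
```
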
11 changes: 9 additions & 2 deletions .github/workflows/run_bench.yml
@@ -37,11 +37,16 @@ jobs:
source bench_venv/bin/activate
python convbench/conv_bench.py
- name: TK Convolutions
run: |
source bench_venv/bin/activate
python convbench/conv_bench.py --tk
- name: Attention
run: |
source bench_venv/bin/activate
python attentionbench/attention_bench.py
- name: TK GEMM
run: |
source bench_venv/bin/activate
@@ -57,11 +62,13 @@ jobs:
source bench_venv/bin/activate
python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8
python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_i8.png --dtype i8
python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_f16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv,results/iree_conv_tk.csv --plot results/combined.png
- name: Upload benchmark results
uses: actions/upload-artifact@v4
19 changes: 15 additions & 4 deletions convbench/conv_bench.py
@@ -11,11 +11,15 @@
from conv_utils import *
from problems import get_conv_configs, get_conv_test_configs

from wave_conv_utils import compile_wave_conv_config

def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
def compile_conv_iree(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
mlir_file, vmfb_file, dump_path = compile_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
return (tag, config, mlir_file, vmfb_file, dump_path)

def compile_conv_wave(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
mlir_file, vmfb_file, dump_path = compile_wave_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
return (tag, config, mlir_file, vmfb_file, dump_path)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Config file updater.")
@@ -42,6 +46,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
parser.add_argument("--model", help="roofline on certain model", default=None)
parser.add_argument('--tk', help="Run conv kernels using Turbine Kernels", action=argparse.BooleanOptionalAction)

args = parser.parse_args()
logging.basicConfig(level=args.log_level)
@@ -71,6 +76,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, extra_compiler_args), configs
)
compile_conv = compile_conv_wave if args.tk else compile_conv_iree
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))

@@ -88,7 +94,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):

results = []
index = 0
output_csv = "results/iree_conv.csv"
output_csv = "results/iree_conv_tk.csv" if args.tk else "results/iree_conv.csv"
entrypoint = "isolated_benchmark" if args.tk else "main"
csv_dir = os.path.dirname(output_csv)
if not os.path.exists(csv_dir):
os.makedirs(csv_dir)
@@ -105,12 +112,16 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
"--function=main",
f"--function={entrypoint}",
"--benchmark_repetitions=3",
f"--input={image_shape}",
f"--input={filter_shape}",
"--benchmark_repetitions=3",
]

if args.tk:
out_shape = config.get_out_shape()
exec_args.append(f"--input={out_shape}")

print(f"Running {vmfb_filename}...")
# iree benchmark kernels
ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args)
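
For reference, the argument list assembled in the hunk above ends up looking roughly like the sketch below on the `--tk` path. The module path and tensor shapes are hypothetical, and the extra third `--input` is the preallocated output buffer expected by the wave kernel's `isolated_benchmark` entrypoint (the plain IREE path keeps `--function=main` and only two inputs):

```python
# Sketch of the benchmark arguments built by the --tk branch above, assuming
# run_iree_command feeds them to iree-benchmark-module. All concrete values
# here are placeholders, not taken from the benchmark suite.
vmfb_filename = "results/vmfbs/conv_2d_nhwc_hwcf_example.vmfb"  # hypothetical path
exec_args = [
    "iree-benchmark-module",
    "--device=hip",
    "--device_allocator=caching",
    f"--module={vmfb_filename}",
    "--function=isolated_benchmark",  # TK entrypoint; the IREE path uses "main"
    "--benchmark_repetitions=3",
    "--input=1x66x66x128xf16",        # image  (N x H_in x W_in x C)
    "--input=3x3x128x256xf16",        # filter (P x Q x C x F)
    "--input=1x64x64x256xf32",        # output buffer, passed only when --tk is set
]
```
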
14 changes: 13 additions & 1 deletion convbench/conv_utils.py
@@ -56,7 +56,19 @@ def get_kernel_shape(self) -> str:
return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype
if "nchw" in self.OP:
return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype


def get_out_shape(self) -> str:
padding = 0
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
h_out = (in_h + 2 * padding - self.P) // self.S + 1
w_out = (in_w + 2 * padding - self.Q) // self.S + 1
n = self.N
nf = self.F
if "nhwc" in self.OP:
return str(n) + "x" + str(h_out) + "x" + str(w_out) + "x" + str(nf) + "x" + self.output_dtype
if "nchw" in self.OP:
return str(n) + "x" + str(nf) + "x" + str(h_out) + "x" + str(w_out) + "x" + self.output_dtype

def get_byte_count(self) -> int:
dtype_bits_map = {
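
As a quick sanity check on the `get_out_shape` arithmetic added above (hypothetical sizes, zero padding as in the helper): with H = W = 64, S = 1, P = Q = 3, the derived input is 66x66 and the computed output comes back to 64x64, i.e. the config's H and W describe the output dimensions.

```python
# Standalone recomputation of get_out_shape's arithmetic with made-up sizes.
H, W, S, P, Q = 64, 64, 1, 3, 3     # hypothetical config values
padding = 0                          # the helper hard-codes zero padding

in_h = H * S + P - 1                 # 66: input height implied by the config
in_w = W * S + Q - 1                 # 66
h_out = (in_h + 2 * padding - P) // S + 1   # 64
w_out = (in_w + 2 * padding - Q) // S + 1   # 64
assert (h_out, w_out) == (H, W)      # the config's H/W are the output dims
```
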
109 changes: 109 additions & 0 deletions convbench/wave_conv_utils.py
@@ -0,0 +1,109 @@
from utils import *
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from conv_utils import ConvConfig
import traceback

try:
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d
from iree.turbine.kernel.wave.utils import (
get_default_arch,
get_default_run_config,
get_default_compile_config,
device_randn,
device_randint,
device_randperm,
device_zeros,
)
except ImportError:
TURBINE_AVAILABLE = False
else:
TURBINE_AVAILABLE = True


def compile_wave_conv_config(
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, extra_compiler_args: list[str]
) -> tuple[Path, Optional[Path], Optional[Path]]:
if not TURBINE_AVAILABLE:
raise ValueError("iree.turbine package is not available")

mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
files_path = vmfb_dir / config.get_name()

try:
_compile_conv(config, mlir_file, vmfb_file)
except Exception as e:
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
with open(error_file, "w") as f:
f.write(str(e))
f.write(traceback.format_exc())
return mlir_file, None, None

return mlir_file, vmfb_file, files_path


def _decode_op(op: str) -> tuple[str, str]:
if op.startswith("conv_2d_"):
return "conv_2d", op[len("conv_2d_") :]

raise ValueError(f"Unsupported op: {op}")


def _convert_dtype(dtype: str):
dtypes = {
"i8": tkl.i8,
"i16": tkl.i16,
"i32": tkl.i32,
"i64": tkl.i64,
"f16": tkl.f16,
"f32": tkl.f32,
"f64": tkl.f64,
"bf16": tkl.bf16,
}
return dtypes[dtype]


def _compile_conv(config: ConvConfig, mlir_file: Path, vmfb_file: Path):
print("Compile TKW kernel", config.OP)
op_type, layout = _decode_op(config.OP)

in_h = config.H * config.S + config.P - 1
in_w = config.W * config.S + config.Q - 1
if op_type == "conv_2d":
conv, hyperparams = get_igemm_conv2d(
layout=layout,
n=config.N,
h=in_h,
w=in_w,
c=config.C,
hf=config.P,
wf=config.Q,
nf=config.F,
stride=config.S,
input_dtype=_convert_dtype(config.input_dtype),
output_dtype=_convert_dtype(config.output_dtype),
)
else:
raise ValueError(f"Unsupported op_type: {op_type}")

# config = get_default_run_config()
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
create_vmfb_file=vmfb_file,
run_config=config,
schedule=False,
inline=False,
):
mod = conv().module_op # This will generate vmfb file
with open(mlir_file, "w") as f:
f.write(str(mod))

print(f"Successfully compiled to {vmfb_file}")
