Skip to content

Commit

Permalink
[ROOFLINE] Roofline analysis over RPC (apache#11252)
Browse files Browse the repository at this point in the history
* [ROOFLINE] Roofline analysis over RPC

Run roofline analysis on remote devices if requested. Peak flops and
peak bandwidth estimation are done on the remote device.

* allocate testing arrays directly on device and randomly fill

* forgot to include remote

* lower flops ratio, machine may be using multiple threads

* forgot fill
  • Loading branch information
Tristan Konolige authored and mehrdadh committed May 16, 2022
1 parent 1c42f85 commit 655fc38
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 18 deletions.
107 changes: 91 additions & 16 deletions python/tvm/utils/roofline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,32 @@
from typing import Dict, Union, Optional
import numpy as np

from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform
from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
from ..target import Target
from ..runtime import profiler_vm, profiling, Device, num_threads
from ..script import tir as T
from ..ir.instrument import pass_instrument
from ..ir.expr import GlobalVar
from ..rpc.base import RPC_SESS_MASK
from ..rpc.client import RPCSession
from ..contrib import utils


def _create_args(mod: IRModule, dev: Device, func_name: str = "main"):
def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
if dev.device_type >= RPC_SESS_MASK:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
else:
random_fill = get_global_func("tvm.contrib.random.random_fill")
assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
args = []
for arg in mod[func_name].params:
args.append(
nd.array(
np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
device=dev,
)
ary = nd.empty(
[x.value for x in arg.type_annotation.shape],
arg.type_annotation.dtype,
device=dev,
)
random_fill(ary)
args.append(ary)
return args


Expand Down Expand Up @@ -103,6 +112,7 @@ def estimate_peak_fma_flops(
dev: Device,
vec_width: Optional[int] = None,
num_vector_registers: Optional[int] = None,
remote: Optional[RPCSession] = None,
) -> float:
"""
Estimate the maximum number of FLOP/s this target/device combo is capable
Expand All @@ -123,6 +133,9 @@ def estimate_peak_fma_flops(
num_vector_registers : Optional[int]
Number of vector registers on the underlying hardware. Will try to
infer if no value is provided.
remote : Optional[RPCSession]
Remote session used to upload artifacts for runtime evaluation. Must be
the same session used to create `dev`.
Returns
-------
Expand All @@ -146,7 +159,23 @@ def estimate_peak_fma_flops(
)
with transform.PassContext(opt_level=3):
f = build(specialized, target=target)
a = nd.array(np.ones((nthreads, num_vector_registers, vec_width), dtype="float32"), device=dev)

# upload to remote if running over rpc
if dev.device_type >= RPC_SESS_MASK:
if remote is None:
raise RuntimeError("A RPCSession must be provided when using a remote device.")
temp = utils.tempdir()
path = temp.relpath("peak_fma_flops.tar")
f.export_library(path)
remote.upload(path)
f = remote.load_module("peak_fma_flops.tar")
random_fill = remote.get_function("tvm.contrib.random.random_fill")
else:
random_fill = get_global_func("tvm.contrib.random.random_fill")
assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"

a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
random_fill(a)
times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops
flop_s = flops / times.min
Expand All @@ -171,7 +200,12 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.
B[i, l, j] += A[i, k, l, j]


def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
def estimate_peak_bandwidth(
target: Target,
dev: Device,
vec_width: Optional[int] = None,
remote: Optional[RPCSession] = None,
) -> float:
"""Estimate peak memory bandwidth of a target/device combo.
Peak bandwidth is estimated by running a small experiment on the underlying
Expand All @@ -187,6 +221,9 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
Device to measure peak bandwidth on.
vec_width : Optional[int]
Vector unit width, determined from target if not supplied.
remote : Optional[RPCSession]
Remote session used to upload artifacts for runtime evaluation. Must be
the same session used to create `dev`.
Returns
-------
Expand All @@ -207,13 +244,30 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
)
with transform.PassContext(opt_level=3):
f = build(specialized, target=target)

# upload to remote if running over rpc
if dev.device_type >= RPC_SESS_MASK:
if remote is None:
raise RuntimeError("A RPCSession must be provided when using a remote device.")
temp = utils.tempdir()
path = temp.relpath("peak_bandwidth.tar")
f.export_library(path)
remote.upload(path)
f = remote.load_module("peak_bandwidth.tar")
random_fill = remote.get_function("tvm.contrib.random.random_fill")
else:
random_fill = get_global_func("tvm.contrib.random.random_fill")
assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"

threads = num_threads()
# Data size needs to be larger than last level of cache. We don't have a
# way of getting cache sizes, so this number should give us a large enough
# size.
size = 10**8 // (4 * threads * vec_width)
a = nd.array(np.ones((threads, size, 4, vec_width), dtype="float32"), device=dev)
b = nd.array(np.ones((threads, vec_width, 4), dtype="float32"), device=dev)
a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
random_fill(a)
b = nd.empty((threads, vec_width, 4), dtype="float32", device=dev)
random_fill(b)
times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
return a.numpy().size * 4 / times.min # 4 bytes per float32

Expand Down Expand Up @@ -241,6 +295,7 @@ def roofline_from_existing(
tir_functions: Dict[GlobalVar, tir.PrimFunc],
target: Target,
dev: Device,
remote: Optional[RPCSession] = None,
) -> profiling.Report:
"""Add roofline and other estimated statistics to an existing profiling report.
Expand Down Expand Up @@ -290,6 +345,9 @@ def roofline_from_existing(
TVM target that `report` was generated with.
dev : Device
Device that `report` was generated with.
remote : Optional[RPCSession]
Remote session used to upload artifacts for runtime evaluation. Must be
the same session used to create `dev`.
Returns
-------
Expand All @@ -299,8 +357,8 @@ def roofline_from_existing(
:py:func:`roofline_analysis` for more information on which metrics
are included.
"""
peak_bandwidth = estimate_peak_bandwidth(target, dev)
peak_flops = estimate_peak_fma_flops(target, dev)
peak_bandwidth = estimate_peak_bandwidth(target, dev, remote=remote)
peak_flops = estimate_peak_fma_flops(target, dev, remote=remote)

ridge_point = peak_flops / peak_bandwidth

Expand Down Expand Up @@ -346,7 +404,11 @@ def roofline_from_existing(


def roofline_analysis(
mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
mod: IRModule,
params: Dict[str, nd.NDArray],
target: Union[str, Target],
dev: Device,
remote: Optional[RPCSession] = None,
) -> profiling.Report:
"""
Create a profiling report that contains roofline and other estimated
Expand Down Expand Up @@ -385,6 +447,10 @@ def roofline_analysis(
dev : Device
Device to run on.
remote : Optional[RPCSession]
Remote session used to upload artifacts for runtime evaluation. Must be
the same session used to create `dev`.
Returns
-------
Expand All @@ -405,9 +471,18 @@ def roofline_analysis(
config=pass_ctx.config,
):
lib = relay.vm.compile(mod, params=params, target=target)
# upload to remote if running over rpc
if dev.device_type >= RPC_SESS_MASK:
if remote is None:
raise RuntimeError("A RPCSession must be provided when using a remote device.")
temp = utils.tempdir()
path = temp.relpath("roofline_lib.tar")
lib.mod.export_library(path)
remote.upload(path)
lib = remote.load_module("roofline_lib.tar")
vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)

args = _create_args(mod, dev)
args = _create_args(mod, dev, remote=remote)
report = vmexec.profile(*args)

return roofline_from_existing(report, save_tir.functions, target, dev)
return roofline_from_existing(report, save_tir.functions, target, dev, remote=remote)
57 changes: 55 additions & 2 deletions tests/python/unittest/test_runtime_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,23 @@ def test_estimate_peak_fma_flops(target, dev):
flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
# Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
assert (
flops > 10**9 * tvm.runtime.num_threads() and flops < 10**14
), f"FLOP/s should be between 10^9 * num_threads and 10^14, but it is {flops}"
flops > 10**9 and flops < 10**14
), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"


def test_estimate_peak_fma_flops_rpc():
target = "llvm -mattr=+fma,+avx2"
server = rpc.Server(key="profiling")
remote = rpc.connect("127.0.0.1", server.port, key="profiling")
dev = remote.device(target)
flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev, remote=remote)
# Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
assert (
flops > 10**9 and flops < 10**14
), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"


@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
@tvm.testing.parametrize_targets("llvm")
def test_estimate_peak_bandwidth(target, dev):
# This test uses vectorized instructions so we need a target that supports them
Expand All @@ -284,6 +297,20 @@ def test_estimate_peak_bandwidth(target, dev):
), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"


@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
def test_estimate_peak_bandwidth_rpc():
target = "llvm -mattr=+fma,+avx2"
server = rpc.Server(key="profiling")
remote = rpc.connect("127.0.0.1", server.port, key="profiling")
dev = remote.device(target)
bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev, remote=remote)
# Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
# GB/s, so this should leave enough wiggle room.
assert (
bandwidth > 10**9 and bandwidth < 10**12
), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"


@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
@tvm.testing.parametrize_targets("llvm")
def test_roofline_analysis(target, dev):
Expand All @@ -304,6 +331,32 @@ def test_roofline_analysis(target, dev):
assert call["Percent of Theoretical Optimal"].ratio >= 0


@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
def test_roofline_analysis_rpc():
target = "llvm"

a = relay.var("a", relay.TensorType((512, 512), "float32"))
b = relay.var("b", relay.TensorType((512, 512), "float32"))
c = relay.nn.dense(a, b)
mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
params = {}

server = rpc.Server(key="profiling")
remote = rpc.connect("127.0.0.1", server.port, key="profiling")
dev = remote.device(target)

report = tvm.utils.roofline_analysis(mod, params, target, dev, remote=remote)

assert "Bound" in report.table()
assert "Percent of Theoretical Optimal" in report.table()
for call in report.calls:
if "Percent of Theoretical Optimal" in call:
# Ideally we'd like a little tighter bound here, but it is hard to
# know how well this dense will perform without tuning. And we
# don't have an operator that uses a specific number of flops.
assert call["Percent of Theoretical Optimal"].ratio >= 0


if __name__ == "__main__":
import sys
import pytest
Expand Down

0 comments on commit 655fc38

Please sign in to comment.