Add autotuning for range() unrolling #219

Open · wants to merge 1 commit into base: main

2 changes: 2 additions & 0 deletions helion/_compiler/device_ir.py
@@ -529,6 +529,8 @@ def visit_For(self, node: ast.For) -> None:
)
outputs: LiftTensorArgs | None = None
begin, end = self._extract_tile_begin_end(node)
if (begin is None or isinstance(begin, int)) and isinstance(end, int):
[Contributor comment] Is this coming from a Triton requirement? Couldn't we unroll dynamic loops too?

CompileEnvironment.current().config_spec.allow_unroll_loops = True
if isinstance(inner_type, SequenceType):
iter_vars = inner_type.unpack()
if begin is None:
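
For context: the new gate only fires when the loop's begin and end resolve to Python ints during type propagation, so only statically-bounded range() loops register the tunable. A minimal sketch of a qualifying kernel (the same shape as test_loop_unroll further down):

import torch

import helion
import helion.language as hl

@helion.kernel()
def qualifies(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile]
        # begin=1 and end=4 are int constants, so the check above sets
        # config_spec.allow_unroll_loops = True for this kernel
        for i in range(1, 4):
            out[tile] += i
    return out

A loop such as for i in range(x.size(0)) has a symbolic end, fails the isinstance(end, int) check, and so exposes no unroll_loops knob.
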
15 changes: 12 additions & 3 deletions helion/_compiler/generate_ast.py
@@ -56,12 +56,16 @@ def mask_var(self, block_idx: int) -> str | None:
return loops[-1].strategy.mask_var(block_idx)
return None

def add_statement(self, stmt: ast.AST | str | None) -> None:
def add_statement(self, stmt: ast.AST | list[ast.AST] | str | None) -> None:
if stmt is None:
return
if isinstance(stmt, str):
stmt = statement_from_string(stmt)
self.statements_stack[-1].append(stmt)
if isinstance(stmt, list):
for s in stmt:
self.statements_stack[-1].append(s)
else:
self.statements_stack[-1].append(stmt)

def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
return self.device_function.unique_name(prefix, dce=dce)
@@ -116,7 +120,12 @@ def add_device_loop(self, device_loop: DeviceLoopState) -> Iterator[None]:
for idx in device_loop.block_ids:
self.active_device_loops[idx].pop()
self.statements_stack[-1].extend(device_loop.outer_prefix)
self.add_statement(device_loop.for_node)
stmt = device_loop.for_node
if self.device_function.config.unroll_loops:
from .static_loop_unroller import unroll_loop

stmt = unroll_loop(node=device_loop.for_node, allow_range=True)
self.add_statement(stmt)
self.statements_stack[-1].extend(device_loop.outer_suffix)

def set_active_loops(self, device_grid: DeviceLoopOrGridState) -> None:
2 changes: 1 addition & 1 deletion helion/_compiler/host_function.py
@@ -103,7 +103,7 @@ def __init__(
from .static_loop_unroller import unroll_static_loops
from .type_propagation import propagate_types

unroll_static_loops(self)
unroll_static_loops(func=self, allow_range=False)
propagate_types(self, fake_args)
env.finalize_config_spec()
self.device_ir = lower_to_device_ir(self)
53 changes: 44 additions & 9 deletions helion/_compiler/static_loop_unroller.py
@@ -23,6 +23,9 @@ class StaticLoopUnroller(ast.NodeTransformer):
TODO(oulgen): This pass is primitive, does not handle for.orelse, break, continue etc
"""

def __init__(self, allow_range: bool) -> None:
self.allow_range = allow_range

def visit_For(self, node: ast.For) -> ast.AST | list[ast.AST]:
# Generic visit to handle nested loops
node = self.generic_visit(node) # pyre-ignore[9]
@@ -45,6 +48,35 @@ def _extract_static_values(self, iter_node: ast.expr) -> list[ast.expr] | None:
"""
if isinstance(iter_node, (ast.List, ast.Tuple)):
return iter_node.elts
if (
self.allow_range
and isinstance(iter_node, ast.Call)
and isinstance(iter_node.func, ast.Name)
and iter_node.func.id == "range"
):
range_values = self._extract_range_values(iter_node)
if range_values is not None:
return [create(ast.Constant, value=val) for val in range_values]

return None

def _extract_range_values(self, range_call: ast.Call) -> list[int] | None:
"""
Extract values from a range() call if all arguments are constants.
"""
args = range_call.args

for arg in args:
if not isinstance(arg, ast.Constant) or not isinstance(arg.value, int):
return None

if len(args) == 1:
return list(range(args[0].value)) # pyre-ignore[16]
if len(args) == 2:
return list(range(args[0].value, args[1].value))
if len(args) == 3:
return list(range(args[0].value, args[1].value, args[2].value))

return None

def _unroll_loop(
@@ -68,14 +100,17 @@ def _unroll_loop(
return unrolled_statements


def unroll_static_loops(func: HostFunction) -> None:
new_body = []
def unroll_loop(*, node: ast.AST, allow_range: bool) -> ast.AST | list[ast.AST]:
try:
return StaticLoopUnroller(allow_range).visit(node)
except CannotUnrollLoop:
return node


def unroll_static_loops(*, func: HostFunction, allow_range: bool) -> None:
new_body: list[ast.stmt] = []
for stmt in func.body:
try:
unrolled_stmts = StaticLoopUnroller().visit(stmt)
except CannotUnrollLoop:
new_body.append(stmt)
else:
assert isinstance(unrolled_stmts, ast.stmt)
new_body.append(unrolled_stmts)
maybe_unrolled = unroll_loop(node=stmt, allow_range=allow_range)
assert isinstance(maybe_unrolled, ast.stmt)
new_body.append(maybe_unrolled)
func.body = new_body
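
As a rough illustration of the new standalone unroll_loop helper: a sketch, not part of this PR's tests, assuming the module path above; the expected output shape is inferred from the generated Triton code in test_loop_unroll3 below.

import ast

from helion._compiler.static_loop_unroller import unroll_loop

loop = ast.parse("for i in range(1, 4):\n    total += i").body[0]
unrolled = unroll_loop(node=loop, allow_range=True)
# Expected shape of the result, by analogy with the codegen test below:
#   i = 1; total += i; i = 2; total += i; i = 3; total += i
stmts = unrolled if isinstance(unrolled, list) else [unrolled]
print("\n".join(ast.unparse(s) for s in stmts))

# With allow_range=False (the host-level pass), the same range() loop is
# returned unchanged, since only list/tuple literals qualify there.
loop2 = ast.parse("for i in range(1, 4):\n    total += i").body[0]
print(ast.unparse(unroll_loop(node=loop2, allow_range=False)))
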
7 changes: 7 additions & 0 deletions helion/autotuner/config_spec.py
@@ -39,6 +39,7 @@
"num_warps",
"num_stages",
"use_yz_grid",
"unroll_loops",
"indexing",
]
)
@@ -65,6 +66,7 @@ class ConfigSpec:
default_factory=dict
)
allow_use_yz_grid: bool | None = None
allow_unroll_loops: bool | None = None

def _remove_duplicates(self) -> None:
self.loop_orders._remove_duplicates()
@@ -111,6 +113,8 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:

if self.allow_use_yz_grid:
config.setdefault("use_yz_grid", False)
if self.allow_unroll_loops:
config.setdefault("unroll_loops", False)

config.setdefault("indexing", "pointer")

@@ -151,6 +155,9 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
not config["flatten_loops"] or not config["flatten_loops"][0]
):
config["use_yz_grid"] = use_yz_grid
if self.allow_unroll_loops:
config["unroll_loops"] = fn(BooleanFragment())

for name in ("loop_orders", "flatten_loops", "reduction_loops", "l2_groupings"):
if not config[name]:
config.pop(name)
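
End to end: a bound kernel whose body contains a constant-bound range() loop now advertises the knob, and the autotuner samples it like any other BooleanFragment. A short sketch, reusing fn, args, and ConfigGeneration exactly as they appear in test_loop_unroll below:

spec = fn.bind(args).config_spec
assert spec.allow_unroll_loops  # registered by the device_ir check above
configs = ConfigGeneration(spec).random_population(10)
# every generated config now carries an explicit unroll_loops entry, e.g.
# helion.Config(block_sizes=[4], ..., unroll_loops=False)
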
7 changes: 7 additions & 0 deletions helion/runtime/config.py
@@ -28,6 +28,7 @@ def __init__(
num_warps: int | None = None,
num_stages: int | None = None,
use_yz_grid: bool | None = None,
unroll_loops: bool | None = None,
indexing: IndexingLiteral | None = None,
# For user-defined properties
**kwargs: object,
@@ -43,6 +44,7 @@ def __init__(
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
use_yz_grid: Whether to use yz grid dimensions.
unroll_loops: Whether to unroll loops.
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
**kwargs: Additional user-defined configuration parameters.
"""
@@ -57,6 +59,7 @@ def __init__(
"num_stages": num_stages,
"indexing": indexing,
"use_yz_grid": use_yz_grid,
"unroll_loops": unroll_loops,
}
for key, value in core_props.items():
if value is not None:
@@ -138,6 +141,10 @@ def l2_groupings(self) -> list[int]:
def use_yz_grid(self) -> bool:
return cast("bool", self.config.get("use_yz_grid", False))

@property
def unroll_loops(self) -> bool:
return cast("bool", self.config.get("unroll_loops", False))

@property
def indexing(self) -> IndexingLiteral:
return self.config.get("indexing", "pointer") # type: ignore
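
Outside the autotuner, the knob can also be set when constructing a config by hand; the new property falls back to False when the key is absent. A small sketch using only the constructor argument and property added in this file:

import helion

cfg = helion.Config(block_sizes=[4], unroll_loops=True)
assert cfg.unroll_loops is True
assert helion.Config(block_sizes=[4]).unroll_loops is False  # default
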
48 changes: 38 additions & 10 deletions test/test_autotuner.py
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
self.assertExpectedInline(
"\n".join(map(repr, configs)),
"""\
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], num_warps=32, num_stages=3, indexing='block_ptr')
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[1, 0]], l2_groupings=[32], num_warps=8, num_stages=8, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=8, num_stages=2, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], num_warps=4, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], num_warps=16, num_stages=5, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], num_warps=16, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=2, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=1, num_stages=1, indexing='tensor_descriptor')""",
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=3, indexing='pointer', unroll_loops=False)
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[1, 0]], l2_groupings=[8], num_warps=32, num_stages=3, indexing='block_ptr', unroll_loops=False)
helion.Config(block_sizes=[128, 16, 128], loop_orders=[[0, 1]], l2_groupings=[8], num_warps=4, num_stages=6, indexing='pointer', unroll_loops=False)
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=7, indexing='tensor_descriptor', unroll_loops=True)
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], num_warps=32, num_stages=2, indexing='tensor_descriptor', unroll_loops=False)
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], num_warps=4, num_stages=7, indexing='tensor_descriptor', unroll_loops=False)
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=4, indexing='tensor_descriptor', unroll_loops=False)
helion.Config(block_sizes=[64, 16, 128], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=32, num_stages=2, indexing='tensor_descriptor', unroll_loops=False)
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=16, num_stages=3, indexing='block_ptr', unroll_loops=False)
helion.Config(block_sizes=[16, 32, 32], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=4, num_stages=7, indexing='block_ptr', unroll_loops=True)""",
)

@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
@@ -187,6 +187,34 @@ def test_differential_evolution_search(self):
fn = bound_kernel.compile_config(best)
torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)

def test_loop_unroll(self):
@helion.kernel()
def fn(x: torch.Tensor) -> torch.Tensor:
out = torch.zeros_like(x)
for tile in hl.tile(x.size()):
out[tile] = x[tile]
for i in range(1, 4):
out[tile] += i
return out

args = (torch.randn(4, device=DEVICE),)
spec = fn.bind(args).config_spec
configs = ConfigGeneration(spec).random_population(10)
self.assertExpectedInline(
"\n".join(map(repr, configs)),
"""\
helion.Config(block_sizes=[4], num_warps=4, num_stages=3, indexing='pointer', unroll_loops=False)
helion.Config(block_sizes=[2], num_warps=32, num_stages=5, indexing='block_ptr', unroll_loops=True)
helion.Config(block_sizes=[1], num_warps=4, num_stages=4, indexing='block_ptr', unroll_loops=False)
helion.Config(block_sizes=[4], num_warps=2, num_stages=8, indexing='block_ptr', unroll_loops=True)
helion.Config(block_sizes=[1], num_warps=2, num_stages=3, indexing='block_ptr', unroll_loops=True)
helion.Config(block_sizes=[4], num_warps=8, num_stages=5, indexing='pointer', unroll_loops=False)
helion.Config(block_sizes=[4], num_warps=2, num_stages=5, indexing='block_ptr', unroll_loops=False)
helion.Config(block_sizes=[1], num_warps=1, num_stages=4, indexing='block_ptr', unroll_loops=True)
helion.Config(block_sizes=[2], num_warps=32, num_stages=7, indexing='pointer', unroll_loops=True)
helion.Config(block_sizes=[2], num_warps=2, num_stages=2, indexing='pointer', unroll_loops=True)""",
)

def test_use_default_config(self):
@helion.kernel(use_default_config=True)
def add(a, b):
97 changes: 97 additions & 0 deletions test/test_loops.py
@@ -1499,6 +1499,103 @@ def _fn_make_precompiler(x: torch.Tensor):
return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

def test_loop_unroll3(self):
@helion.kernel()
def fn(x: torch.Tensor) -> torch.Tensor:
out = torch.zeros_like(x)
for tile in hl.tile(x.size()):
out[tile] = x[tile]
for i in range(1, 4):
out[tile] += i
return out

x = torch.randn(4, device=DEVICE)
code, output = code_and_output(fn, (x,), block_sizes=[4], unroll_loops=True)
torch.testing.assert_close(output, x + 6)
self.assertExpectedInline(
code,
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
tl.store(out + indices_0 * out_stride_0, load, mask_0)
offset_1 = 1
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
v_0 = offset_1.to(tl.float32)
v_1 = load_1 + v_0
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
offset_1 = 2
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
v_0 = offset_1.to(tl.float32)
v_1 = load_1 + v_0
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
offset_1 = 3
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
v_0 = offset_1.to(tl.float32)
v_1 = load_1 + v_0
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)

def fn(x: torch.Tensor):
out = torch.zeros_like(x)
_BLOCK_SIZE_0 = 4
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor):
out = torch.zeros_like(x)
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

code, output = code_and_output(fn, (x,), block_sizes=[4], unroll_loops=False)
torch.testing.assert_close(output, x + 6)
self.assertExpectedInline(
code,
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
tl.store(out + indices_0 * out_stride_0, load, mask_0)
for offset_1 in range(1, 4, 1):
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
v_0 = offset_1.to(tl.float32)
v_1 = load_1 + v_0
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)

def fn(x: torch.Tensor):
out = torch.zeros_like(x)
_BLOCK_SIZE_0 = 4
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor):
out = torch.zeros_like(x)
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)


if __name__ == "__main__":
unittest.main()