Skip to content

Commit

Permalink
Support default_order=auto in kernel creation
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jul 27, 2023
1 parent 022c8dc commit c37030a
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 88 deletions.
2 changes: 2 additions & 0 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,8 @@ def __init__(self, name, dtype=None, shape=None, dim_tags=None, offset=0,
n_axes=num_user_axes,
use_increasing_target_axes=self.max_target_axes > 1,
dim_names=dim_names)

if dim_tags is not None:
order = None

# }}}
Expand Down
34 changes: 30 additions & 4 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@

from pymbolic.mapper import CSECachingMapperMixin
from pymbolic.primitives import Slice, Variable, Subscript, Call
from loopy.kernel.array import FixedStrideArrayDimTag
from loopy.tools import intern_frozenset_of_ids, Optional
from loopy.symbolic import (
IdentityMapper, WalkMapper, SubArrayRef)
from loopy.kernel.data import (
InstructionBase,
MultiAssignmentBase, Assignment,
SubstitutionRule, AddressSpace, ValueArg)
SubstitutionRule, AddressSpace, ValueArg, auto)
from loopy.translation_unit import for_each_kernel
from loopy.diagnostic import LoopyError, warn_with_kernel
import islpy as isl
Expand Down Expand Up @@ -1732,8 +1733,30 @@ def apply_default_order_to_args(kernel, default_order):

processed_args = []
for arg in kernel.args:
if isinstance(arg, ArrayBase) and arg.order is None:
arg = arg.copy(order=default_order)
if isinstance(arg, ArrayBase):
if default_order in ["c", "f", "C", "F"]:
if arg.dim_tags is None:
arg = arg.copy(order=default_order)
else:
# leave them the way they are
pass
elif default_order is auto:
if arg.dim_tags is None and arg.shape is not None:
assert arg.shape is not auto
arg = arg.copy(
dim_tags=tuple(
FixedStrideArrayDimTag(auto)
for i in range(len(arg.shape))))
arg = arg.copy(
dim_tags=tuple(
FixedStrideArrayDimTag(auto)
if isinstance(dim_tag, FixedStrideArrayDimTag)
else dim_tag
for dim_tag in arg.dim_tags))
else:
raise ValueError("unexpected value for default_order: "
f"'{default_order}'")

processed_args.append(arg)

return kernel.copy(args=processed_args)
Expand Down Expand Up @@ -2196,7 +2219,10 @@ def make_function(domains, instructions, kernel_data=None, **kwargs):
:arg preamble_generators: a list of functions of signature
(seen_dtypes, seen_functions) where seen_functions is a set of
(name, c_name, arg_dtypes), generating extra entries for *preambles*.
:arg default_order: "C" (default) or "F"
:arg default_order: "C" (default), "F" or :class:`loopy.auto`.
The default memory layout of arrays for which no layout was
explicitly specified. If :class:`loopy.auto`, variables for strides
are created automatically.
:arg default_offset: 0 or :class:`loopy.auto`. The default value of
*offset* in :attr:`ArrayArg` for guessed arguments.
Defaults to 0.
Expand Down
11 changes: 11 additions & 0 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def _lpy_even_div(a, b):
# FIXME: This error message is kind of crummy.
raise ValueError("expected even division")
return result
def _lpy_even_div_none(a, b):
if a is None:
return None
result, remdr = divmod(a, b)
if remdr != 0:
# FIXME: This error message is kind of crummy.
raise ValueError("expected even division")
return result
"""


Expand Down
186 changes: 102 additions & 84 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ class _ArgFindingEquation:
lhs: ExpressionT
rhs: ExpressionT

# Arg finding code is sorted by priority, lowest order first
# Arg finding code is sorted by priority, all equations (across all unknowns)
# of lowest priority first.
order: int

based_on_names: FrozenSet[str]
require_names: bool


class ExecutionWrapperGeneratorBase(ABC):
Expand Down Expand Up @@ -164,8 +164,6 @@ def generate_integer_arg_finding_from_array_data(

equations: List[_ArgFindingEquation] = []

from pymbolic.primitives import If

for arg_name in kai.passed_arg_names:
arg = kernel.arg_dict[arg_name]
assert arg.dtype is not None
Expand All @@ -179,10 +177,10 @@ def generate_integer_arg_finding_from_array_data(
lhs=var(arg.name).attr("shape").index(axis_nr),
rhs=shape_i,
order=0,
based_on_names=frozenset({arg.name}),
require_names=True))
based_on_names=frozenset({arg.name})))

for axis_nr, stride_i in enumerate(get_strides(arg)):
strides = get_strides(arg)
for axis_nr, stride_i in enumerate(strides):
if stride_i is not None:
equations.append(
_ArgFindingEquation(
Expand All @@ -192,43 +190,68 @@ def generate_integer_arg_finding_from_array_data(
rhs=_str_to_expr(stride_i),
order=0,
based_on_names=frozenset({arg.name}),
require_names=True))

if arg.offset is not None:
if not kernel.options.no_numpy:
offset = var("getattr")(var(arg.name), var('"offset"'), 0)
else:
offset = var(arg.name).attr("offset")
))

offset = If(var(f"{arg.name} is None"), 0, offset)
if not arg.is_input and isinstance(arg.shape, tuple):
# If no value was found by other means, provide
# C-contiguous default strides for output-only
# arguments.
equations.append(
_ArgFindingEquation(
lhs=(strides[axis_nr + 1]
* arg.shape[axis_nr + 1])
if axis_nr + 1 < len(strides)
else 1,
rhs=_str_to_expr(stride_i),
# Find strides from last dim to first,
# starting at order=1 so that shape
# parameters (found above) are
# available.
order=len(strides) - axis_nr,
based_on_names=frozenset(),
))

if arg.offset is not None:
equations.append(
_ArgFindingEquation(
lhs=var("_lpy_even_div")(
offset, arg.dtype.itemsize),
lhs=var("_lpy_even_div_none")(
var("getattr")(
var(arg.name), var('"offset"'), var("None")),
arg.dtype.itemsize),
rhs=_str_to_expr(arg.offset),
order=0,
based_on_names=frozenset([arg.name]),
))

# Argument finding from offsets should run last,
# as it assumes a zero offset if a variable is
# not passed. That should only be done if no
# other approach yielded a value for the variable.
# If no value was found by other means, default to zero.
equations.append(
_ArgFindingEquation(
lhs=0,
rhs=_str_to_expr(arg.offset),
order=1,
based_on_names=frozenset(arg.name),
require_names=False,
based_on_names=frozenset(),
))

# }}}

# {{{ regroup equations by unknown

unknown_to_equations: Dict[str, List[_ArgFindingEquation]] = {}
order_to_unknown_to_equations: \
Dict[int, Dict[str, List[_ArgFindingEquation]]] = {}

for eqn in equations:
deps = dep_map(eqn.rhs)

if len(deps) == 1:
unknown_var, = deps
unknown_to_equations.setdefault(unknown_var.name, []).append((eqn))
order_to_unknown_to_equations \
.setdefault(eqn.order, {}) \
.setdefault(unknown_var.name, []) \
.append((eqn))
else:
# Zero deps: nothing to determine, forget about it.
# 2+ deps: not implemented
pass

del equations

Expand All @@ -243,72 +266,67 @@ def generate_integer_arg_finding_from_array_data(
gen("# {{{ find integer arguments from array data")
gen("")

for unknown_name in sorted(unknown_to_equations):
unk_equations = sorted(unknown_to_equations[unknown_name],
key=lambda eqn: eqn.order)
req_subgen = CodeGenerator()
not_req_subgen = CodeGenerator()
for order_value in sorted(order_to_unknown_to_equations):
for unknown_name in sorted(order_to_unknown_to_equations[order_value]):
unk_equations = sorted(
order_to_unknown_to_equations[order_value][unknown_name],
key=lambda eqn: eqn.order)
subgen = CodeGenerator()

seen_based_on_names: Set[FrozenSet[str]] = set()
seen_based_on_names: Set[FrozenSet[str]] = set()

if_or_elif = "if"
if_or_elif = "if"

for eqn in unk_equations:
try:
# overkill :)
value_expr = solve_affine_equations_for(
[unknown_name],
[(eqn.lhs, eqn.rhs)]
)[Variable(unknown_name)]
except Exception as e:
# went wrong? oh well
from warnings import warn
warn("Unable to generate code to automatically "
f"find '{unknown_name}' "
f"from '{', '.join(eqn.based_on_names)}':\n"
f"{e}", ParameterFinderWarning)
continue

# Do not use more than one bit of data from each of the
# 'based_on_names' to find each value, i.e. if a value can be
# found via shape and strides, only one of them suffices.
# This also helps because strides can be unreliable in the
# face of zero-length axes.
if eqn.based_on_names in seen_based_on_names:
continue
seen_based_on_names.add(eqn.based_on_names)

if eqn.require_names:
condition = " and ".join(
f"{ary_name} is not None"
for ary_name in eqn.based_on_names)
req_subgen(f"{if_or_elif} {condition}:")
with Indentation(req_subgen):
req_subgen(
for eqn in unk_equations:
if eqn.rhs == Variable(unknown_name):
# Some of the expressions above are non-affine. Let's not
# get carried away by trying to solve a much more complex
# problem than needed.
value_expr = eqn.lhs
else:
try:
# overkill :)
value_expr = solve_affine_equations_for(
[unknown_name],
[(eqn.lhs, eqn.rhs)]
)[Variable(unknown_name)]
except Exception as e:
# went wrong? oh well
from warnings import warn
warn("Unable to generate code to automatically "
f"find '{unknown_name}' "
f"from '{', '.join(eqn.based_on_names)}':\n"
f"{e}", ParameterFinderWarning)
continue

# Do not use more than one bit of data from each of the
# 'based_on_names' to find each value, i.e. if a value can be
# found via shape and strides, only one of them suffices.
# This also helps because strides can be unreliable in the
# face of zero-length axes.
if eqn.based_on_names in seen_based_on_names:
continue
seen_based_on_names.add(eqn.based_on_names)

if eqn.based_on_names:
condition = " and ".join(
f"{ary_name} is not None"
for ary_name in eqn.based_on_names)
else:
condition = "True"

subgen(f"{if_or_elif} {condition}:")
with Indentation(subgen):
subgen(
f"{unknown_name} = {StringifyMapper()(value_expr)}")
if_or_elif = "elif"

req_subgen("")
else:
not_req_subgen(
f"{unknown_name} = {StringifyMapper()(value_expr)}")

not_req_subgen("")

if not_req_subgen.code:
gen(f"if {unknown_name} is None:")
with Indentation(gen):
gen.extend(not_req_subgen)
subgen("")

if req_subgen.code:
# still? try the req_subgen
gen(f"if {unknown_name} is None:")
with Indentation(gen):
gen.extend(req_subgen)
elif req_subgen.code:
gen(f"if {unknown_name} is None:")
with Indentation(gen):
gen.extend(req_subgen)
if subgen.code:
gen(f"if {unknown_name} is None:")
with Indentation(gen):
gen.extend(subgen)

gen("# }}}")
gen("")
Expand Down
23 changes: 23 additions & 0 deletions test/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import loopy as lp
import numpy as np
import pyopencl as cl
import pyopencl.array

from pyopencl.tools import \
pytest_generate_tests_for_pyopencl as pytest_generate_tests # noqa
Expand Down Expand Up @@ -140,6 +141,28 @@ def test_einsum_array_ops_triple_prod(ctx_factory, spec):
assert np.linalg.norm(out - ans) <= 1e-15


def test_einsum_with_variable_strides(ctx_factory):
    """Run an einsum kernel on an input whose strides are not contiguous.

    The kernel is built with ``default_order=lp.auto`` and
    ``default_offset=lp.auto``. The first operand is a transposed view,
    both on the host and on the device, and the device result is compared
    against :func:`numpy.einsum` evaluated on the host data.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    spec = "ijk,jl->il"
    knl = lp.make_einsum(spec, ("a", "b"),
            default_order=lp.auto, default_offset=lp.auto)

    # Seeded generator: a failure can be reproduced with the same inputs.
    rng = np.random.default_rng(seed=42)
    a_untransposed = rng.standard_normal((3, 5, 4))
    b = rng.standard_normal((4, 5))

    # Transpose on host and device alike so that both views share the
    # same (non-contiguous) strides.
    a = a_untransposed.transpose((0, 2, 1))
    a_dev = cl.array.to_device(queue, a_untransposed).transpose((0, 2, 1))
    assert a_dev.strides == a.strides

    _evt, (result,) = knl(queue, a=a_dev, b=b)

    ref = np.einsum(spec, a, b)

    assert np.allclose(result.get(), ref)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit c37030a

Please sign in to comment.