Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,9 @@ def make_einsum(spec, arg_names, **knl_creation_kwargs):
for idx in all_indices
)

if "name" not in knl_creation_kwargs:
knl_creation_kwargs["name"] = "einsum%dto%d_kernel" % (
len(all_indices), len(out_indices))
knl_creation_kwargs.setdefault("name", "einsum%dto%d_kernel" % (
len(all_indices), len(out_indices)))
knl_creation_kwargs.setdefault("lang_version", MOST_RECENT_LANGUAGE_VERSION)

return make_kernel("{[%s]: %s}" % (",".join(sorted(all_indices)), constraints),
[Assignment(lhs, rhs)],
Expand Down
6 changes: 3 additions & 3 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def handle_alloc(self, gen, arg, kernel_arg, strify, skip_arg_checks):
#check strides
if not skip_arg_checks:
strides_check_expr = self.get_strides_check_expr(
(strify(s) for s in sym_shape),
(strify(s) for s in sym_strides),
(strify(s) for s in expected_strides))
[strify(s) for s in sym_shape],
[strify(s) for s in sym_strides],
[strify(s) for s in expected_strides])
gen("assert %(strides_check)s, "
"'Strides of loopy created array %(name)s, "
"do not match expected.'" %
Expand Down
33 changes: 23 additions & 10 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,26 @@ def handle_alloc(self, gen, arg, kernel_arg, strify, skip_arg_checks):
def get_arg_pass(self, arg):
    """Abstract hook; this base implementation always raises.

    :raises NotImplementedError: unconditionally, regardless of *arg*.
    """
    raise NotImplementedError

def get_strides_check_expr(self, shape, strides, expected_strides):
    """Return the source text of an expression that checks an argument's strides.

    :arg shape: iterable of strings, one expression per axis length.
    :arg strides: iterable of strings, one expression per actual stride.
    :arg expected_strides: iterable of strings, one expression per
        expected stride.
    :returns: a string holding a boolean expression. It is the literal
        ``"True"`` when there are no axes.

    All three arguments must describe the same number of axes. Any
    iterable (including a generator) is accepted.
    """
    # Materialize first: generators support neither len() nor a meaningful
    # truthiness test, and each sequence is traversed more than once below.
    shape = tuple(shape)
    strides = tuple(strides)
    expected_strides = tuple(expected_strides)

    assert len(shape) == len(strides) == len(expected_strides)

    if not shape:
        # No axes, hence nothing to compare.
        return "True"

    # Shape axes of length 1 are ignored because strides along these
    # axes are never used: The only valid index is 0.
    per_axis_checks = [
        f"({shape_i} == 1 or {strides_i} == {expected_strides_i})"
        for shape_i, strides_i, expected_strides_i
        in zip(shape, strides, expected_strides)]

    # If any shape component is zero, the array is empty and the strides
    # don't matter.
    empty_escapes = [f" or not {shape_i}" for shape_i in shape]

    return "(" + " and ".join(per_axis_checks) + ")" + "".join(empty_escapes)

# {{{ arg setup

Expand Down Expand Up @@ -545,21 +559,20 @@ def strify_tuple(t):
gen("if not (%s):"
% self.get_strides_check_expr(
shape, strides,
(strify(s) for s in sym_strides)))
[strify(s) for s in sym_strides]))
with Indentation(gen):
gen("_lpy_got = tuple(stride "
"for (dim, stride) in zip(%s.shape, %s.strides) "
"if dim > 1)"
")"
% (arg.name, arg.name))
gen("_lpy_expected = tuple(stride "
"for (dim, stride) in zip(%s.shape, %s) "
"if dim > 1)"
")"
% (arg.name, strify_tuple(sym_strides)))

gen('raise TypeError("strides mismatch on '
gen('raise ValueError("strides mismatch on '
"argument '%s' "
"(after removing unit length dims, "
'got: %%s, expected: %%s)" '
'(got: %%s, expected: %%s)" '
"%% (_lpy_got, _lpy_expected))"
% arg.name)

Expand Down
2 changes: 1 addition & 1 deletion test/test_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2752,7 +2752,7 @@ def test_shape_mismatch_check(ctx_factory):
if t_unit["loopy_kernel"].options.skip_arg_checks:
pytest.skip("args checks disabled, cannot check")

with pytest.raises(TypeError, match="strides mismatch"):
with pytest.raises(ValueError, match="strides mismatch"):
t_unit(queue, a=a, b=b)


Expand Down
30 changes: 30 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,36 @@ def test_empty_array_output(ctx_factory):
assert out.shape == (0,)


def test_empty_array_stride_check(ctx_factory):
    """Run an einsum kernel on an empty operand, then on a Fortran-ordered one.

    The call with a zero-length axis must complete; the call with a
    non-empty array copied in Fortran order must raise ``ValueError``.
    """
    queue = cl.CommandQueue(ctx_factory())

    knl = lp.make_einsum("mij,j->mi", ["a", "x"])

    # First operand has a zero-length middle axis.
    knl(queue, a=np.random.randn(3, 0, 5), x=np.random.randn(5))

    # The second half relies on argument checks being active.
    if knl.default_entrypoint.options.skip_arg_checks:
        pytest.skip("args checks disabled, cannot check")

    with pytest.raises(ValueError):
        fortran_a = np.random.randn(3, 2, 5).copy(order="F")
        knl(queue, a=fortran_a, x=np.random.randn(5))


def test_empty_array_stride_check_fortran(ctx_factory):
    """Invoke a kernel on an empty, Fortran-ordered device array.

    See https://github.com/inducer/loopy/issues/583
    """
    queue = cl.CommandQueue(ctx_factory())

    import pyopencl.array as cl_array

    # Shape (0, 2): the leading axis is empty.
    empty_input = cl_array.Array(queue, (0, 2), np.float64, order="F")

    sqrt_knl = lp.make_kernel(
        "{ [i,j]: 0<=i<n and 0<=j<m }",
        "output[i,j] = sqrt(input[i,j])")

    sqrt_knl(queue, input=empty_input)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down