Skip to content

Commit

Permalink
use read/write_events everywhere
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexfikl committed Oct 11, 2024
1 parent c43d2c5 commit a62047c
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 43 deletions.
14 changes: 7 additions & 7 deletions pyopencl/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _extract_extra_args_types_values(extra_args):
if isinstance(val, cl.array.Array):
extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
extra_args_values.append(val)
extra_wait_for.extend(val.events)
extra_wait_for.extend(val.write_events)
elif isinstance(val, np.generic):
extra_args_types.append(ScalarArg(val.dtype, name))
extra_args_values.append(val)
Expand Down Expand Up @@ -1163,7 +1163,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand All @@ -1182,7 +1182,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
counts = cl.array.empty(queue,
(n_objects + 1), index_dtype, allocator=allocator)
counts[-1] = 0
wait_for = wait_for + counts.events
wait_for = wait_for + counts.write_events

# The scan will turn the "counts" array into the "starts" array
# in-place.
Expand Down Expand Up @@ -1236,7 +1236,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
info_record.nonempty_indices,
info_record.compressed_indices,
info_record.num_nonempty_lists,
wait_for=[count_event, *info_record.compressed_indices.events])
wait_for=[count_event, *info_record.compressed_indices.write_events])

info_record.starts = compressed_counts

Expand Down Expand Up @@ -1265,13 +1265,13 @@ def __call__(self, queue, n_objects, *args, **kwargs):
evt = scan_kernel(
starts_ary,
size=info_record.num_nonempty_lists,
wait_for=starts_ary.events)
wait_for=starts_ary.write_events)
else:
evt = scan_kernel(starts_ary, wait_for=[count_event],
size=n_objects)

starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
scan_events.extend(starts_ary.events)
scan_events.extend(starts_ary.write_events)

# retrieve count
info_record.count = int(starts_ary[-1].get())
Expand Down Expand Up @@ -1433,7 +1433,7 @@ def __call__(self, queue, keys, values, nkeys,

starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
.fill(len(values_sorted_by_key), wait_for=[evt]))
evt, = starts.events
evt, = starts.write_events

evt = knl_info.start_finder(starts, keys_sorted_by_key,
range=slice(len(keys_sorted_by_key)),
Expand Down
2 changes: 1 addition & 1 deletion pyopencl/bitonic_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __call__(self, arr, idx=None, queue=None, wait_for=None, axis=0):

if wait_for is None:
wait_for = []
wait_for = wait_for + arr.events
wait_for = wait_for + arr.write_events

last_evt = cl.enqueue_marker(queue, wait_for=wait_for)

Expand Down
50 changes: 21 additions & 29 deletions pyopencl/clmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@

import numpy as np

import pyopencl.array as cl_array
import pyopencl.elementwise as elementwise
from pyopencl.array import _get_common_dtype
from pyopencl.array import _get_common_dtype, elwise_kernel_runner


def _make_unary_array_func(name):
@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def knl_runner(result, arg):
if arg.dtype.kind == "c":
from pyopencl.elementwise import complex_dtype_to_name
Expand All @@ -43,8 +42,7 @@ def knl_runner(result, arg):

def f(array, queue=None):
result = array._new_like_me(queue=queue)
event1 = knl_runner(result, array, queue=queue)
result.add_event(event1)
knl_runner(result, array, queue=queue)
return result

return f
Expand All @@ -60,13 +58,13 @@ def f(array, queue=None):
asinpi = _make_unary_array_func("asinpi")


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _atan2(result, arg1, arg2):
return elementwise.get_float_binary_func_kernel(
result.context, "atan2", arg1.dtype, arg2.dtype, result.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _atan2pi(result, arg1, arg2):
return elementwise.get_float_binary_func_kernel(
result.context, "atan2pi", arg1.dtype, arg2.dtype, result.dtype)
Expand All @@ -81,7 +79,7 @@ def atan2(y, x, queue=None):
"""
queue = queue or y.queue
result = y._new_like_me(_get_common_dtype(y, x, queue))
result.add_event(_atan2(result, y, x, queue=queue))
_atan2(result, y, x, queue=queue)
return result


Expand All @@ -95,7 +93,7 @@ def atan2pi(y, x, queue=None):
"""
queue = queue or y.queue
result = y._new_like_me(_get_common_dtype(y, x, queue))
result.add_event(_atan2pi(result, y, x, queue=queue))
_atan2pi(result, y, x, queue=queue)
return result


Expand All @@ -122,7 +120,7 @@ def atan2pi(y, x, queue=None):
# TODO: fmin


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _fmod(result, arg, mod):
return elementwise.get_fmod_kernel(result.context, result.dtype,
arg.dtype, mod.dtype)
Expand All @@ -133,13 +131,13 @@ def fmod(arg, mod, queue=None):
for each element in ``arg`` and ``mod``."""
queue = (queue or arg.queue) or mod.queue
result = arg._new_like_me(_get_common_dtype(arg, mod, queue))
result.add_event(_fmod(result, arg, mod, queue=queue))
_fmod(result, arg, mod, queue=queue)
return result

# TODO: fract


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _frexp(sig, expt, arg):
return elementwise.get_frexp_kernel(sig.context, sig.dtype,
expt.dtype, arg.dtype)
Expand All @@ -151,9 +149,7 @@ def frexp(arg, queue=None):
"""
sig = arg._new_like_me(queue=queue)
expt = arg._new_like_me(queue=queue, dtype=np.int32)
event1 = _frexp(sig, expt, arg, queue=queue)
sig.add_event(event1)
expt.add_event(event1)
_frexp(sig, expt, arg, queue=queue, noutputs=2)
return sig, expt

# TODO: hypot
Expand All @@ -162,7 +158,7 @@ def frexp(arg, queue=None):
ilogb = _make_unary_array_func("ilogb")


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _ldexp(result, sig, exp):
return elementwise.get_ldexp_kernel(result.context, result.dtype,
sig.dtype, exp.dtype)
Expand All @@ -174,7 +170,7 @@ def ldexp(significand, exponent, queue=None):
``result = significand * 2**exponent``.
"""
result = significand._new_like_me(queue=queue)
result.add_event(_ldexp(result, significand, exponent))
_ldexp(result, significand, exponent)
return result


Expand All @@ -192,7 +188,7 @@ def ldexp(significand, exponent, queue=None):
# TODO: minmag


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _modf(intpart, fracpart, arg):
return elementwise.get_modf_kernel(intpart.context, intpart.dtype,
fracpart.dtype, arg.dtype)
Expand All @@ -204,9 +200,7 @@ def modf(arg, queue=None):
"""
intpart = arg._new_like_me(queue=queue)
fracpart = arg._new_like_me(queue=queue)
event1 = _modf(intpart, fracpart, arg, queue=queue)
fracpart.add_event(event1)
intpart.add_event(event1)
_modf(intpart, fracpart, arg, queue=queue, noutputs=2)
return fracpart, intpart


Expand Down Expand Up @@ -239,19 +233,19 @@ def modf(arg, queue=None):
# TODO: table 6.10, integer functions
# TODO: table 6.12, clamp et al

@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _bessel_jn(result, n, x):
return elementwise.get_bessel_kernel(result.context, "j", result.dtype,
np.dtype(type(n)), x.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _bessel_yn(result, n, x):
return elementwise.get_bessel_kernel(result.context, "y", result.dtype,
np.dtype(type(n)), x.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _hankel_01(h0, h1, x):
if h0.dtype != h1.dtype:
raise TypeError("types of h0 and h1 must match")
Expand All @@ -261,20 +255,18 @@ def _hankel_01(h0, h1, x):

def bessel_jn(n, x, queue=None):
result = x._new_like_me(queue=queue)
result.add_event(_bessel_jn(result, n, x, queue=queue))
_bessel_jn(result, n, x, queue=queue)
return result


def bessel_yn(n, x, queue=None):
result = x._new_like_me(queue=queue)
result.add_event(_bessel_yn(result, n, x, queue=queue))
_bessel_yn(result, n, x, queue=queue)
return result


def hankel_01(x, queue=None):
h0 = x._new_like_me(queue=queue)
h1 = x._new_like_me(queue=queue)
event1 = _hankel_01(h0, h1, x, queue=queue)
h0.add_event(event1)
h1.add_event(event1)
_hankel_01(h0, h1, x, queue=queue, noutputs=2)
return h0, h1
2 changes: 1 addition & 1 deletion pyopencl/clrandom.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _fill(self, distribution, ary, scale, shift, queue=None):
gsize, lsize = _splay(queue.device, ary.size)

evt = knl(queue, gsize, lsize, *args)
ary.add_event(evt)
ary.add_write_event(evt)

self.counter[0] += n * counter_multiplier
c1_incr, self.counter[0] = divmod(self.counter[0], self.counter_max)
Expand Down
2 changes: 1 addition & 1 deletion pyopencl/invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def add_buf_arg(arg_idx, typechar, expr_str):
cl_arg_idx += 1

if in_enqueue:
wait_for_parts .append(f"{arg_var}.events")
wait_for_parts.append(f"{arg_var}.write_events")

continue

Expand Down
4 changes: 2 additions & 2 deletions pyopencl/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
invocation_args.append(arg.base_data)
if arg_tp.with_offset:
invocation_args.append(arg.offset)
wait_for.extend(arg.events)
wait_for.extend(arg.write_events)
else:
invocation_args.append(arg)

Expand Down Expand Up @@ -523,7 +523,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
wait_for=wait_for)
wait_for = [last_evt]

result.add_event(last_evt)
result.add_write_event(last_evt)

if group_count == 1:
if return_event:
Expand Down
4 changes: 2 additions & 2 deletions pyopencl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,7 +1533,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand Down Expand Up @@ -1766,7 +1766,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand Down

0 comments on commit a62047c

Please sign in to comment.