Skip to content

Commit 16e2488

Browse files
adarshyoga authored and Diptorup Deb committed
implementation of compare exchange with accompanying test cases
1 parent df32f71 commit 16e2488

File tree

4 files changed

+346
-3
lines changed

4 files changed

+346
-3
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from .spv_fn_generator import (
3636
get_or_insert_atomic_load_fn,
37+
get_or_insert_spv_atomic_compare_exchange_fn,
3738
get_or_insert_spv_atomic_exchange_fn,
3839
get_or_insert_spv_atomic_store_fn,
3940
)
@@ -323,6 +324,108 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
323324
return sig, _intrinsic_exchange_gen
324325

325326

327+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_compare_exchange(
    ty_context,  # pylint: disable=unused-argument
    ty_atomic_ref,
    ty_expected_ref,
    ty_desired,
    ty_expected_idx,
):
    """Intrinsic that lowers a compare-exchange on an AtomicRef.

    The typed signature is
    ``boolean(ty_atomic_ref, ty_expected_ref, ty_desired, ty_expected_idx)``.

    The code generator loads ``expected[expected_idx]``, calls the function
    returned by ``get_or_insert_spv_atomic_compare_exchange_fn`` and compares
    the value it returns with the loaded expected value. When the two
    differ, the returned value is written back to ``expected[expected_idx]``.
    The result of that comparison is the lowered return value.

    Returns:
        A ``(signature, codegen)`` pair.
    """
    sig = types.boolean(
        ty_atomic_ref, ty_expected_ref, ty_desired, ty_expected_idx
    )

    def _intrinsic_compare_exchange_gen(context, builder, sig, args):
        atomic_ref_ty = sig.args[0]
        expected_ref_ty = sig.args[1]
        dmm = context.data_model_manager

        # Build a pointer to expected[expected_idx] from the "data" field
        # of the expected_ref argument.
        data_pos = dmm.lookup(expected_ref_ty).get_field_position("data")
        expected_data = builder.extract_value(args[1], data_pos)
        # Only the stack slot is emitted in the entry block; the store and
        # the subsequent load/gep stay at the current insertion point.
        with builder.goto_entry_block():
            data_slot = builder.alloca(expected_data.type)
        builder.store(expected_data, data_slot)
        expected_ref_ptr = builder.gep(builder.load(data_slot), [args[3]])

        expected_arg = builder.load(expected_ref_ptr)
        desired_arg = args[2]

        ref_pos = dmm.lookup(atomic_ref_ty).get_field_position("ref")
        atomic_ref_ptr = builder.extract_value(args[0], ref_pos)

        # For floating point references, the pointer, the expected value
        # and the desired value are bitcast to the integer type of the
        # same bit-width before the call; other dtypes are passed as-is.
        if atomic_ref_ty.dtype == types.float32:
            int_ty, float_ty = llvmir.IntType(32), llvmir.FloatType()
        elif atomic_ref_ty.dtype == types.float64:
            int_ty, float_ty = llvmir.IntType(64), llvmir.DoubleType()
        else:
            int_ty = float_ty = None

        if int_ty is not None:
            atomic_ref_ptr = builder.bitcast(
                atomic_ref_ptr,
                llvmir.PointerType(
                    int_ty, addrspace=atomic_ref_ty.address_space
                ),
            )
            expected_arg = builder.bitcast(expected_arg, int_ty)
            desired_arg = builder.bitcast(desired_arg, int_ty)

        scope_const = context.get_constant(
            types.int32, get_scope(atomic_ref_ty.memory_scope)
        )
        # The same memory-semantics mask, derived from the AtomicRef's
        # memory_order, is passed twice (third and fourth call arguments).
        semantics_const = context.get_constant(
            types.int32,
            get_memory_semantics_mask(atomic_ref_ty.memory_order),
        )

        cmpexchg_fn = get_or_insert_spv_atomic_compare_exchange_fn(
            context, builder.module, atomic_ref_ty
        )
        ret_val = builder.call(
            cmpexchg_fn,
            [
                atomic_ref_ptr,
                scope_const,
                semantics_const,
                semantics_const,
                desired_arg,
                expected_arg,
            ],
        )

        # compare_exchange returns the old value stored in the AtomicRef
        # object. If it equals expected, the exchange succeeded in
        # replacing the AtomicRef object with desired. If not, the returned
        # value is stored into expected. Either way the comparison result
        # is what gets returned.
        is_cmp_exchg_success = builder.icmp_signed("==", ret_val, expected_arg)

        with builder.if_else(is_cmp_exchg_success) as (then, otherwise):
            with then:
                pass
            with otherwise:
                # Undo the integer bitcast before writing back to expected.
                if float_ty is not None:
                    ret_val = builder.bitcast(ret_val, float_ty)
                builder.store(ret_val, expected_ref_ptr)

        return is_cmp_exchg_success

    return sig, _intrinsic_compare_exchange_gen
427+
428+
326429
def _check_if_supported_ref(ref):
327430
supported = True
328431

@@ -689,3 +792,94 @@ def ol_exchange_impl(atomic_ref, val):
689792
return _intrinsic_exchange(atomic_ref, val)
690793

691794
return ol_exchange_impl
795+
796+
797+
@overload_method(
    AtomicRefType,
    "compare_exchange_weak",
    target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_compare_exchange_weak(
    atomic_ref, expected_ref, desired, expected_idx=0
):  # pylint: disable=unused-argument
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange_weak`.

    Emits the same LLVM IR instruction that dpcpp generates for
    `atomic_ref::compare_exchange_weak`. The returned implementation
    forwards all four arguments to ``_intrinsic_compare_exchange``.

    Raises:
        TypingError: When the dtype of ``expected_ref`` or the type of
            ``desired`` differs from the dtype of the AtomicRef type.
            ``expected_ref`` is also passed through
            ``_check_if_supported_ref`` before these checks.
    """

    _check_if_supported_ref(expected_ref)

    ref_dtype = atomic_ref.dtype

    # The expected array must hold elements of the AtomicRef's dtype.
    if ref_dtype != expected_ref.dtype:
        raise errors.TypingError(
            f"Type of value to compare_exchange_weak: {expected_ref} does not match the "
            f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
        )

    # The desired value must be of the AtomicRef's dtype as well.
    if ref_dtype != desired:
        raise errors.TypingError(
            f"Type of value to compare_exchange_weak: {desired} does not match the "
            f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
        )

    def ol_compare_exchange_weak_impl(
        atomic_ref, expected_ref, desired, expected_idx=0
    ):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_compare_exchange(
            atomic_ref, expected_ref, desired, expected_idx
        )

    return ol_compare_exchange_weak_impl
839+
840+
841+
@overload_method(
842+
AtomicRefType,
843+
"compare_exchange_strong",
844+
target=DPEX_KERNEL_EXP_TARGET_NAME,
845+
)
846+
def ol_compare_exchange_strong(
847+
atomic_ref,
848+
expected_ref,
849+
desired,
850+
expected_idx=0, # pylint: disable=unused-argument
851+
):
852+
"""SPIR-V overload for
853+
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange_strong`.
854+
855+
Generates the same LLVM IR instruction as dpcpp for the
856+
`atomic_ref::compare_exchange_strong` function.
857+
858+
Raises:
859+
TypingError: When the dtype of the value passed to `compare_exchange_strong`
860+
does not match the dtype of the AtomicRef type.
861+
"""
862+
863+
_check_if_supported_ref(expected_ref)
864+
865+
if atomic_ref.dtype != expected_ref.dtype:
866+
raise errors.TypingError(
867+
f"Type of value to compare_exchange_strong: {expected_ref} does not match the "
868+
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
869+
)
870+
871+
if atomic_ref.dtype != desired:
872+
raise errors.TypingError(
873+
f"Type of value to compare_exchange_strong: {desired} does not match the "
874+
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
875+
)
876+
877+
def ol_compare_exchange_strong_impl(
878+
atomic_ref, expected_ref, desired, expected_idx=0
879+
):
880+
# pylint: disable=no-value-for-parameter
881+
return _intrinsic_compare_exchange(
882+
atomic_ref, expected_ref, desired, expected_idx
883+
)
884+
885+
return ol_compare_exchange_strong_impl

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_fn_generator.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,59 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
119119
fn.calling_convention = CC_SPIR_FUNC
120120

121121
return fn
122+
123+
124+
def get_or_insert_spv_atomic_compare_exchange_fn(
    context, module, atomic_ref_ty
):
    """
    Gets or inserts a declaration for a __spirv_AtomicCompareExchange call
    into the specified LLVM IR module.

    The declared function takes a pointer in the AtomicRef's address space,
    three 32-bit integers, and two values of the element type, and returns
    a value of the element type. The SPIR calling convention is set on the
    function before it is returned.
    """
    value_dtype = atomic_ref_ty.dtype

    # Spirv spec requires arguments and return type to be of integer types,
    # so float dtypes are replaced by the unsigned integer type of the same
    # bit-width. Callers bitcast their operands to match this declaration.
    for float_dtype, uint_dtype in (
        (types.float32, types.uint32),
        (types.float64, types.uint64),
    ):
        if value_dtype == float_dtype:
            value_dtype = uint_dtype
            break

    addrspace = atomic_ref_ty.address_space
    llvm_value_ty = context.get_value_type(value_dtype)
    llvm_ptr_ty = llvmir.PointerType(llvm_value_ty, addrspace=addrspace)
    i32 = llvmir.IntType(32)

    fn_type = llvmir.FunctionType(
        llvm_value_ty,
        [llvm_ptr_ty, i32, i32, i32, llvm_value_ty, llvm_value_ty],
    )

    mangled_fn_name = ext_itanium_mangler.mangle_ext(
        "__spirv_AtomicCompareExchange",
        [
            types.CPointer(value_dtype, addrspace=addrspace),
            "__spv.Scope.Flag",
            "__spv.MemorySemanticsMask.Flag",
            "__spv.MemorySemanticsMask.Flag",
            value_dtype,
            value_dtype,
        ],
    )

    fn = cgutils.get_or_insert_function(module, fn_type, mangled_fn_name)
    fn.calling_convention = CC_SPIR_FUNC

    return fn

numba_dpex/kernel_api/atomic_ref.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,55 @@ def exchange(self, val):
200200
old = self._ref[self._index]
201201
self._ref[self._index] = val
202202
return old
203+
204+
def compare_exchange_weak(self, expected, desired, expected_idx=0):
    """Compare the referenced value with ``expected[expected_idx]`` and
    replace it with ``desired`` on a match.

    When the value referenced by the AtomicRef equals
    ``expected[expected_idx]``, it is overwritten with ``desired``.
    Otherwise the referenced value is copied into
    ``expected[expected_idx]`` and the referenced object is left untouched.

    Args:
        expected : Vector holding the value the referenced object is
            expected to contain.
        desired : Value to store into the referenced object when the
            comparison succeeds.
        expected_idx: Offset into ``expected`` at which the expected
            value is located.

    Returns: ``True`` if the values matched and the replacement was
        made, ``False`` otherwise.

    """
    current = self._ref[self._index]
    if current != expected[expected_idx]:
        # Mismatch: report the value actually observed back to the caller.
        expected[expected_idx] = current
        return False
    self._ref[self._index] = desired
    return True
229+
230+
def compare_exchange_strong(self, expected, desired, expected_idx=0):
    """Compare the referenced value with ``expected[expected_idx]`` and
    replace it with ``desired`` on a match.

    When the value referenced by the AtomicRef equals
    ``expected[expected_idx]``, it is overwritten with ``desired``.
    Otherwise the referenced value is copied into
    ``expected[expected_idx]`` and the referenced object is left untouched.

    Args:
        expected : Vector holding the value the referenced object is
            expected to contain.
        desired : Value to store into the referenced object when the
            comparison succeeds.
        expected_idx: Offset into ``expected`` at which the expected
            value is located.

    Returns: ``True`` if the values matched and the replacement was
        made, ``False`` otherwise.

    """
    current = self._ref[self._index]
    if current != expected[expected_idx]:
        # Mismatch: report the value actually observed back to the caller.
        expected[expected_idx] = current
        return False
    self._ref[self._index] = desired
    return True

0 commit comments

Comments
 (0)