Skip to content

Commit a9645d7

Browse files
[pallas:triton] Fix atomic min/max lowering for unsigned integers and float types
1 parent d5e5b42 commit a9645d7

File tree

2 files changed

+80
-5
lines changed

2 files changed

+80
-5
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,65 @@ def _atomic_rmw(
474474
result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
475475
)
476476

477+
def _fp_bits_type(t: ir.Type) -> ir.Type:
  """Returns a signless integer type with the same bit layout as ``t``.

  Recurses structurally: a ranked tensor of floats maps to a ranked tensor of
  same-width integers (shape and encoding preserved), and a Triton pointer to a
  float maps to a pointer to the matching integer type (address space
  preserved). The leaf case requires a float type and yields the signless
  integer of identical width, so values can be reinterpreted bit-for-bit
  (used by the atomic min/max emulation).
  """
  if ir.RankedTensorType.isinstance(t):
    tensor = ir.RankedTensorType(t)
    return ir.RankedTensorType.get(
        tensor.shape, _fp_bits_type(tensor.element_type), tensor.encoding
    )
  if tt_dialect.PointerType.isinstance(t):
    pointer = tt_dialect.PointerType(t)
    return tt_dialect.PointerType.get(
        _fp_bits_type(pointer.pointee_type), pointer.address_space
    )
  # Leaf: only float scalars are expected here; same-width signless int.
  assert isinstance(t, ir.FloatType)
  return ir.IntegerType.get_signless(t.width)
491+
492+
493+
def _expand_atomic_fp_min_max(
    atomic_type: primitives.AtomicOpType,
    ptr: ir.Value,
    val: ir.Value,
    mask: ir.Value | None = None,
    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
  """Emulates a floating-point atomic min/max using integer atomic RMW ops.

  IEEE-754 floats with the same sign compare like their bit patterns viewed
  as integers, so the pointer and value are bitcast to same-width integers
  and two complementary atomics are issued: one over the lanes whose value is
  non-negative and one over the negative lanes. For MAX these are signed MAX
  and unsigned MIN respectively (and the mirror pair for MIN). The two
  results are then merged per-lane and bitcast back to the float type.

  Mirrors the expansion in Triton's python/triton/language/semantic.py.
  NaNs are not handled.
  TODO: Check what the Pallas semantics should be.
  """
  # Compute the float result type to bitcast back to at the end.
  if ir.RankedTensorType.isinstance(ptr.type):
    tensor_type = ir.RankedTensorType(ptr.type)
    pointee = tt_dialect.PointerType(tensor_type.element_type).pointee_type
    result_type = ir.RankedTensorType.get(
        tensor_type.shape, pointee, tensor_type.encoding
    )
  else:
    result_type = tt_dialect.PointerType(ptr.type).pointee_type

  # Reinterpret both the pointer and the operand as integer bits.
  int_ptr = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
  int_val = tt_dialect.bitcast(_fp_bits_type(val.type), val)

  # Partition lanes by the sign bit of the operand.
  zero_bits = _full(int_val.type, 0)
  is_nonneg = _greater_equal(int_val, zero_bits, signed=True)
  is_neg = _less_than(int_val, zero_bits, signed=True)

  nonneg_mask = (
      is_nonneg if mask is None else arith_dialect.andi(mask, is_nonneg)
  )
  neg_mask = is_neg if mask is None else arith_dialect.andi(mask, is_neg)

  if atomic_type == primitives.AtomicOpType.MAX:
    nonneg_op, neg_op = tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN
  else:
    nonneg_op, neg_op = tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX

  nonneg_result = _atomic_rmw(
      nonneg_op,
      int_ptr,
      int_val,
      mask=nonneg_mask,
      semantic=semantic,
      sync_scope=sync_scope,
  )
  neg_result = _atomic_rmw(
      neg_op,
      int_ptr,
      int_val,
      mask=neg_mask,
      semantic=semantic,
      sync_scope=sync_scope,
  )
  # Pick, per lane, whichever of the two atomics was actually armed.
  merged = arith_dialect.select(is_nonneg, nonneg_result, neg_result)
  return tt_dialect.bitcast(result_type, merged)
477536

478537
@register_lowering(primitives.atomic_rmw_p)
479538
def _atomic_lowering_rule(
@@ -501,9 +560,23 @@ def _atomic_lowering_rule(
501560
else:
502561
op = tt_dialect.RMWOp.FADD
503562
elif atomic_type == primitives.AtomicOpType.MIN:
504-
op = tt_dialect.RMWOp.MIN
563+
if isinstance(val.type, ir.IntegerType):
564+
op = (
565+
tt_dialect.RMWOp.MIN
566+
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
567+
else tt_dialect.RMWOp.UMIN
568+
)
569+
else:
570+
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
505571
elif atomic_type == primitives.AtomicOpType.MAX:
506-
op = tt_dialect.RMWOp.MAX
572+
if isinstance(val.type, ir.IntegerType):
573+
op = (
574+
tt_dialect.RMWOp.MAX
575+
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
576+
else tt_dialect.RMWOp.UMAX
577+
)
578+
else:
579+
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
507580
elif atomic_type == primitives.AtomicOpType.AND:
508581
op = tt_dialect.RMWOp.AND
509582
elif atomic_type == primitives.AtomicOpType.OR:

tests/pallas/ops_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,12 +1863,14 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref):
18631863

18641864
@parameterized.named_parameters(
18651865
("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum),
1866-
("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
1866+
("max_i32", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
18671867
("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min),
1868+
("max_u32", pl.atomic_max, np.array([1, 2, 3, 4], np.uint32), np.max),
1869+
("min_u32", pl.atomic_min, np.array([1, 2, 3, 4], np.uint32), np.min),
18681870
("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum),
18691871
("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum),
1870-
("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max),
1871-
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
1872+
("max_f32", pl.atomic_max, np.array([-2, -1, 0, 1], np.float32), np.max),
1873+
("min_f32", pl.atomic_min, np.array([-2, -1, 0, 1], np.float32), np.min),
18721874
)
18731875
def test_scalar_atomic(self, op, value, numpy_op):
18741876
# The Pallas TPU lowering currently supports only blocks of rank >= 1

0 commit comments

Comments
 (0)