Skip to content

Commit ae7f6b6

Browse files
[pallas:triton] Fix atomic min/max lowering for unsigned integers and float types
1 parent d5e5b42 commit ae7f6b6

File tree

2 files changed

+91
-5
lines changed

2 files changed

+91
-5
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,76 @@ def _atomic_rmw(
474474
result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
475475
)
476476

477+
def _fp_bits_type(t: ir.Type) -> ir.Type:
  """Returns ``t`` with every float element replaced by a same-width signless int.

  Recurses through ranked tensor and Triton pointer wrappers; the innermost
  type must be a float type.
  """
  if ir.RankedTensorType.isinstance(t):
    tensor = ir.RankedTensorType(t)
    return ir.RankedTensorType.get(
        tensor.shape, _fp_bits_type(tensor.element_type), tensor.encoding
    )
  if tt_dialect.PointerType.isinstance(t):
    ptr = tt_dialect.PointerType(t)
    return tt_dialect.PointerType.get(
        _fp_bits_type(ptr.pointee_type), ptr.address_space
    )
  assert isinstance(t, ir.FloatType)
  return ir.IntegerType.get_signless(t.width)
491+
492+
493+
def _expand_atomic_fp_min_max(
    atomic_type: primitives.AtomicOpType,
    ptr: ir.Value,
    val: ir.Value,
    mask: ir.Value | None = None,
    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
  """Expands a floating point atomic min/max via a pair of integer atomics.

  The float operands are bitcast to same-width integers and dispatched on the
  sign of the value (scheme taken from triton's
  python/triton/language/semantic.py; it relies on the bit-pattern ordering of
  IEEE floats and does not handle NaNs):

  min:
    return atomic_smin(i_ptr, i_val) if i_val >= 0 else atomic_umax(i_ptr, i_val)

  max:
    return atomic_smax(i_ptr, i_val) if i_val >= 0 else atomic_umin(i_ptr, i_val)
  """

  # Recover the float result type from the pointer type: either a tensor of
  # pointers (vectorized atomic) or a single pointer.
  if ir.RankedTensorType.isinstance(ptr.type):
    ptr_type = ir.RankedTensorType(ptr.type)
    element_type = tt_dialect.PointerType(ptr_type.element_type)
    result_type = ir.RankedTensorType.get(
        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
    )
  else:
    result_type = tt_dialect.PointerType(ptr.type).pointee_type

  # Reinterpret both the pointer and the value as same-width signless ints.
  ptr_cast = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
  val_cast = tt_dialect.bitcast(_fp_bits_type(val.type), val)

  # Signed comparison against 0 on the bit pattern tests the float's sign.
  zero = _full(val_cast.type, 0)
  pos_cmp = _greater_equal(val_cast, zero, signed=True)
  neg_cmp = _less_than(val_cast, zero, signed=True)

  # Each atomic below must only fire on the sign half it is valid for; fold
  # the sign predicate into the caller-provided mask, if any.
  pos_mask = pos_cmp if mask is None else arith_dialect.andi(mask, pos_cmp)
  neg_mask = neg_cmp if mask is None else arith_dialect.andi(mask, neg_cmp)

  pos_op, neg_op = (
      (tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN)
      if atomic_type == primitives.AtomicOpType.MAX
      else (tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX)
  )
  # Taken from triton's python/triton/language/semantic.py
  # Doesn't handle nans
  # TODO: Check what Pallas semantics should be
  pos_val = _atomic_rmw(
      pos_op, ptr_cast, val_cast, mask=pos_mask, semantic=semantic, sync_scope=sync_scope
  )
  neg_val = _atomic_rmw(
      neg_op, ptr_cast, val_cast, mask=neg_mask, semantic=semantic, sync_scope=sync_scope
  )
  # Pick the result of whichever atomic was enabled for this value's sign,
  # then bitcast the integer bits back to the original float type.
  result = arith_dialect.select(pos_cmp, pos_val, neg_val)
  return tt_dialect.bitcast(result_type, result)
477547

478548
@register_lowering(primitives.atomic_rmw_p)
479549
def _atomic_lowering_rule(
@@ -501,9 +571,23 @@ def _atomic_lowering_rule(
501571
else:
502572
op = tt_dialect.RMWOp.FADD
503573
elif atomic_type == primitives.AtomicOpType.MIN:
504-
op = tt_dialect.RMWOp.MIN
574+
if isinstance(val.type, ir.IntegerType):
575+
op = (
576+
tt_dialect.RMWOp.MIN
577+
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
578+
else tt_dialect.RMWOp.UMIN
579+
)
580+
else:
581+
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
505582
elif atomic_type == primitives.AtomicOpType.MAX:
506-
op = tt_dialect.RMWOp.MAX
583+
if isinstance(val.type, ir.IntegerType):
584+
op = (
585+
tt_dialect.RMWOp.MAX
586+
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
587+
else tt_dialect.RMWOp.UMAX
588+
)
589+
else:
590+
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
507591
elif atomic_type == primitives.AtomicOpType.AND:
508592
op = tt_dialect.RMWOp.AND
509593
elif atomic_type == primitives.AtomicOpType.OR:

tests/pallas/ops_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,12 +1863,14 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref):
18631863

18641864
@parameterized.named_parameters(
18651865
("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum),
1866-
("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
1866+
("max_i32", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
18671867
("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min),
1868+
("max_u32", pl.atomic_max, np.array([1, 2, 3, 4], np.uint32), np.max),
1869+
("min_u32", pl.atomic_min, np.array([1, 2, 3, 4], np.uint32), np.min),
18681870
("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum),
18691871
("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum),
1870-
("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max),
1871-
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
1872+
("max_f32", pl.atomic_max, np.array([-2, -1, 0, 1], np.float32), np.max),
1873+
("min_f32", pl.atomic_min, np.array([-2, -1, 0, 1], np.float32), np.min),
18721874
)
18731875
def test_scalar_atomic(self, op, value, numpy_op):
18741876
# The Pallas TPU lowering currently supports only blocks of rank >= 1

0 commit comments

Comments
 (0)