@@ -474,6 +474,65 @@ def _atomic_rmw(
      result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
  )

+def _fp_bits_type(t: ir.Type) -> ir.Type:
+  if ir.RankedTensorType.isinstance(t):
+    t_type = ir.RankedTensorType(t)
+    return ir.RankedTensorType.get(
+        t_type.shape, _fp_bits_type(t_type.element_type), t_type.encoding
+    )
+  elif tt_dialect.PointerType.isinstance(t):
+    ptr_type = tt_dialect.PointerType(t)
+    return tt_dialect.PointerType.get(
+        _fp_bits_type(ptr_type.pointee_type), ptr_type.address_space
+    )
+  else:
+    assert isinstance(t, ir.FloatType)
+    return ir.IntegerType.get_signless(t.width)
+
+
+def _expand_atomic_fp_min_max(
+    atomic_type: primitives.AtomicOpType,
+    ptr: ir.Value,
+    val: ir.Value,
+    mask: ir.Value | None = None,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+) -> ir.Value:
+  if ir.RankedTensorType.isinstance(ptr.type):
+    ptr_type = ir.RankedTensorType(ptr.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    result_type = ir.RankedTensorType.get(
+        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
+    )
+  else:
+    result_type = tt_dialect.PointerType(ptr.type).pointee_type
+
+  ptr_cast = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
+  val_cast = tt_dialect.bitcast(_fp_bits_type(val.type), val)
+
+  zero = _full(val_cast.type, 0)
+  pos_cmp = _greater_equal(val_cast, zero, signed=True)
+  neg_cmp = _less_than(val_cast, zero, signed=True)
+
+  pos_mask = pos_cmp if mask is None else arith_dialect.andi(mask, pos_cmp)
+  neg_mask = neg_cmp if mask is None else arith_dialect.andi(mask, neg_cmp)
+
+  pos_op, neg_op = (
+      (tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN)
+      if atomic_type == primitives.AtomicOpType.MAX
+      else (tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX)
+  )
+  # Taken from Triton's python/triton/language/semantic.py.
+  # Doesn't handle NaNs.
+  # TODO: Check what the Pallas semantics should be.
+  pos_val = _atomic_rmw(
+      pos_op, ptr_cast, val_cast, mask=pos_mask, semantic=semantic, sync_scope=sync_scope
+  )
+  neg_val = _atomic_rmw(
+      neg_op, ptr_cast, val_cast, mask=neg_mask, semantic=semantic, sync_scope=sync_scope
+  )
+  result = arith_dialect.select(pos_cmp, pos_val, neg_val)
+  return tt_dialect.bitcast(result_type, result)

@register_lowering(primitives.atomic_rmw_p)
def _atomic_lowering_rule(
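The `_expand_atomic_fp_min_max` helper added above builds float atomic min/max out of Triton's integer RMW ops by exploiting an IEEE-754 ordering property: for non-negative floats, comparing the raw bit patterns as signed integers agrees with comparing the floats, while for negative floats the unsigned order of the bit patterns is reversed. A standalone NumPy sketch of that trick for the `MAX` case (illustration only, not part of the change; like the lowering, which follows Triton's `semantic.py`, it ignores NaNs):

```python
import numpy as np

def fp_max_via_int_ops(old, val):
  # Mirrors the split in _expand_atomic_fp_min_max for AtomicOpType.MAX:
  # signed integer max where the incoming value is non-negative,
  # unsigned integer min where it is negative.
  old_bits = np.asarray(old, np.float32).view(np.int32)
  val_bits = np.asarray(val, np.float32).view(np.int32)
  # Non-negative lane: the bit patterns (as signed ints) order like the floats.
  pos = np.maximum(old_bits, val_bits)
  # Negative lane: the unsigned bit-pattern order is reversed, so an
  # unsigned min picks the larger (less negative) float.
  neg = np.minimum(
      old_bits.view(np.uint32), val_bits.view(np.uint32)
  ).view(np.int32)
  return np.where(val_bits >= 0, pos, neg).view(np.float32)

old = np.float32([1.5, -1.0, -0.5, 2.0])
val = np.float32([2.5, -3.0, 4.0, -7.0])
np.testing.assert_array_equal(fp_max_via_int_ops(old, val), np.maximum(old, val))
```

In the lowering itself the same split is expressed as two masked `_atomic_rmw` calls on the bitcast pointer and value, and the final `arith_dialect.select` merges the two returned (pre-existing) values before bitcasting back to the float result type.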
@@ -501,9 +560,23 @@ def _atomic_lowering_rule(
    else:
      op = tt_dialect.RMWOp.FADD
  elif atomic_type == primitives.AtomicOpType.MIN:
-    op = tt_dialect.RMWOp.MIN
+    if isinstance(val.type, ir.IntegerType):
+      op = (
+          tt_dialect.RMWOp.MIN
+          if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
+          else tt_dialect.RMWOp.UMIN
+      )
+    else:
+      return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
  elif atomic_type == primitives.AtomicOpType.MAX:
-    op = tt_dialect.RMWOp.MAX
+    if isinstance(val.type, ir.IntegerType):
+      op = (
+          tt_dialect.RMWOp.MAX
+          if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
+          else tt_dialect.RMWOp.UMAX
+      )
+    else:
+      return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
  elif atomic_type == primitives.AtomicOpType.AND:
    op = tt_dialect.RMWOp.AND
  elif atomic_type == primitives.AtomicOpType.OR:
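The integer branches now distinguish signed from unsigned operands via the JAX dtype of the value's abstract value (`value_aval.dtype`), while floating-point operands fail the `ir.IntegerType` check and fall through to `_expand_atomic_fp_min_max`. A quick check of the predicate driving that dispatch (illustrative only):

```python
import jax.numpy as jnp

# Signed integer dtypes keep the signed RMW ops ...
assert jnp.issubdtype(jnp.int32, jnp.signedinteger)       # -> RMWOp.MIN / RMWOp.MAX
# ... unsigned dtypes get the unsigned variants ...
assert not jnp.issubdtype(jnp.uint32, jnp.signedinteger)  # -> RMWOp.UMIN / RMWOp.UMAX
# ... and float dtypes take the _expand_atomic_fp_min_max path instead.
assert jnp.issubdtype(jnp.float32, jnp.floating)
```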