@@ -474,6 +474,76 @@ def _atomic_rmw(
474
474
result_type , op , ptr , val , mask = mask , sem = semantic , scope = sync_scope
475
475
)
476
476
477
def _fp_bits_type(t: ir.Type) -> ir.Type:
  """Return the signless-integer counterpart of a floating point type.

  Ranked tensor types and Triton pointer types are rebuilt around the
  converted element / pointee type, keeping their shape, encoding and
  address space, so the result can be used as the target of a bitcast.

  Args:
    t: A float type, or a ranked tensor / pointer type (possibly nested)
      whose innermost type is a float type.

  Returns:
    A type with the same structure as ``t`` in which the innermost float
    type is replaced by a signless integer type of the same bit width.

  Raises:
    TypeError: If the innermost type is not an ``ir.FloatType``.
  """
  if ir.RankedTensorType.isinstance(t):
    tensor_type = ir.RankedTensorType(t)
    return ir.RankedTensorType.get(
        tensor_type.shape,
        _fp_bits_type(tensor_type.element_type),
        tensor_type.encoding,
    )
  if tt_dialect.PointerType.isinstance(t):
    ptr_type = tt_dialect.PointerType(t)
    return tt_dialect.PointerType.get(
        _fp_bits_type(ptr_type.pointee_type), ptr_type.address_space
    )
  # Explicit raise instead of `assert` so the check survives `python -O`.
  if not isinstance(t, ir.FloatType):
    raise TypeError(f"Expected a floating point type, but got: {t}")
  return ir.IntegerType.get_signless(t.width)
491
+
492
+
493
def _expand_atomic_fp_min_max(
    atomic_type: primitives.AtomicOpType,
    ptr: ir.Value,
    val: ir.Value,
    mask: ir.Value | None = None,
    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
  """Expands floating point atomic min/max into integer atomic min/max.

  Both ``ptr`` and ``val`` are bitcast to their integer-bits equivalents and
  two masked integer atomics are issued, one for lanes whose value bits are
  non-negative (as a signed integer) and one for the remaining lanes:

  min:
    atomic_smin(i_ptr, i_val) if i_val >= 0 else atomic_umax(i_ptr, i_val)

  max:
    atomic_smax(i_ptr, i_val) if i_val >= 0 else atomic_umin(i_ptr, i_val)

  Does not handle NaNs.

  Args:
    atomic_type: Either ``AtomicOpType.MIN`` or ``AtomicOpType.MAX``.
    ptr: Pointer (or ranked tensor of pointers) to floating point data.
    val: Floating point value(s) to combine with the memory contents.
    mask: Optional predicate; when given it is AND-ed with the sign
      predicate of each of the two atomics.
    semantic: Memory semantic forwarded to both atomics.
    sync_scope: Synchronization scope forwarded to both atomics.

  Returns:
    The per-lane selection of the two atomics' results, bitcast back to the
    pointee floating point type of ``ptr``.

  Raises:
    ValueError: If ``atomic_type`` is neither ``MIN`` nor ``MAX``.
  """
  # Validate before emitting any IR so a bad call leaves nothing half-built.
  if atomic_type == primitives.AtomicOpType.MAX:
    pos_op, neg_op = tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN
  elif atomic_type == primitives.AtomicOpType.MIN:
    pos_op, neg_op = tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX
  else:
    raise ValueError(
        f"Unsupported atomic type for floating point min/max: {atomic_type}"
    )

  # The result has the pointee type of `ptr`, with the tensor shape and
  # encoding preserved when `ptr` is a tensor of pointers.
  if ir.RankedTensorType.isinstance(ptr.type):
    ptr_type = ir.RankedTensorType(ptr.type)
    element_type = tt_dialect.PointerType(ptr_type.element_type)
    result_type = ir.RankedTensorType.get(
        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
    )
  else:
    result_type = tt_dialect.PointerType(ptr.type).pointee_type

  # Reinterpret the floats (and the pointer's pointee) as raw integer bits.
  ptr_cast = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
  val_cast = tt_dialect.bitcast(_fp_bits_type(val.type), val)

  # Split lanes by the sign of the integer bits, i.e. by the float sign bit.
  zero = _full(val_cast.type, 0)
  pos_cmp = _greater_equal(val_cast, zero, signed=True)
  neg_cmp = _less_than(val_cast, zero, signed=True)

  pos_mask = pos_cmp if mask is None else arith_dialect.andi(mask, pos_cmp)
  neg_mask = neg_cmp if mask is None else arith_dialect.andi(mask, neg_cmp)

  # Taken from triton's python/triton/language/semantic.py
  # Doesn't handle nans
  # TODO: Check what Pallas semantics should be
  pos_val = _atomic_rmw(
      pos_op,
      ptr_cast,
      val_cast,
      mask=pos_mask,
      semantic=semantic,
      sync_scope=sync_scope,
  )
  neg_val = _atomic_rmw(
      neg_op,
      ptr_cast,
      val_cast,
      mask=neg_mask,
      semantic=semantic,
      sync_scope=sync_scope,
  )
  result = arith_dialect.select(pos_cmp, pos_val, neg_val)
  return tt_dialect.bitcast(result_type, result)
477
547
478
548
@register_lowering (primitives .atomic_rmw_p )
479
549
def _atomic_lowering_rule (
@@ -501,9 +571,23 @@ def _atomic_lowering_rule(
501
571
else :
502
572
op = tt_dialect .RMWOp .FADD
503
573
elif atomic_type == primitives .AtomicOpType .MIN :
504
- op = tt_dialect .RMWOp .MIN
574
+ if isinstance (val .type , ir .IntegerType ):
575
+ op = (
576
+ tt_dialect .RMWOp .MIN
577
+ if jnp .issubdtype (value_aval .dtype , jnp .signedinteger )
578
+ else tt_dialect .RMWOp .UMIN
579
+ )
580
+ else :
581
+ return _expand_atomic_fp_min_max (atomic_type , ptr , val , mask = mask )
505
582
elif atomic_type == primitives .AtomicOpType .MAX :
506
- op = tt_dialect .RMWOp .MAX
583
+ if isinstance (val .type , ir .IntegerType ):
584
+ op = (
585
+ tt_dialect .RMWOp .MAX
586
+ if jnp .issubdtype (value_aval .dtype , jnp .signedinteger )
587
+ else tt_dialect .RMWOp .UMAX
588
+ )
589
+ else :
590
+ return _expand_atomic_fp_min_max (atomic_type , ptr , val , mask = mask )
507
591
elif atomic_type == primitives .AtomicOpType .AND :
508
592
op = tt_dialect .RMWOp .AND
509
593
elif atomic_type == primitives .AtomicOpType .OR :
0 commit comments