@@ -613,51 +613,16 @@ def doprint(self, expr, *, simplify: bool = True, p=True):
613
613
return super ().doprint (expr )
614
614
615
615
616
- class OpOverrides :
617
- def __init__ (self , parent ):
618
- super ().__init__ ()
619
- self ._parent = parent
620
-
621
- @staticmethod
622
- def paren (string : str ) -> str :
623
- def all_in_parens (string : str ) -> bool :
624
- if string [0 ] != "(" or len (string ) < 2 :
625
- return False
626
- count = 1
627
- for i , char in enumerate (string [1 :]):
628
- if char == "(" :
629
- count += 1
630
- elif char == ")" :
631
- count -= 1
632
- if count == 0 and i != len (string ) - 2 :
633
- return False
634
- assert count == 0
635
- return True
636
-
637
- if (
638
- isinstance (string , CSEVariable )
639
- or re .match (r"^[a-z0-9_.]+$" , string , re .IGNORECASE )
640
- or re .match (r"^\([^)]*\)$" , string , re .IGNORECASE )
641
- or string == ""
642
- ):
643
- return string
644
- # don't put extra parens for strings that are already wrapped in parens
645
- if all_in_parens (string ):
646
- return string
647
- return f"({ string } )"
648
-
649
- def __getattr__ (self , item ):
650
- return getattr (self ._parent , item )
616
+ class OpDecompositions :
617
+ """
618
+ Decomposes inductor ops
619
+ """
651
620
652
621
@staticmethod
653
622
def identity (value ):
654
623
# used to trigger cse
655
624
return value
656
625
657
- @staticmethod
658
- def constant (value , dtype ):
659
- return repr (value )
660
-
661
626
@staticmethod
662
627
def reciprocal (x ):
663
628
return ops .truediv (ops .constant (1 , torch .int32 ), x )
@@ -699,15 +664,86 @@ def sigmoid(x):
699
664
one = ops .constant (1 , torch .int32 )
700
665
return ops .truediv (one , ops .add (one , ops .exp (ops .neg (x ))))
701
666
667
+ @staticmethod
668
+ def relu (x ):
669
+ return ops .maximum (x , ops .constant (0 , torch .int32 ))
670
+
671
+ @staticmethod
672
+ def fma (x , y , z ):
673
+ # for backends that don't override this (halide)
674
+ return ops .add (ops .mul (x , y ), z )
675
+
676
+ @staticmethod
677
+ def floor_to_int (a , dtype ):
678
+ return ops .to_dtype (ops .floor (a ), dtype )
679
+
680
+ @staticmethod
681
+ def ceil_to_int (a , dtype ):
682
+ return ops .to_dtype (ops .ceil (a ), dtype )
683
+
684
+ @staticmethod
685
+ def trunc_to_int (a , dtype ):
686
+ return ops .to_dtype (ops .trunc (a ), dtype )
687
+
688
+ @staticmethod
689
+ def remainder (a , b ):
690
+ r = ops .mod (a , b )
691
+ cond = ops .and_ (
692
+ ops .ne (r , ops .constant (0 , torch .int32 )),
693
+ ops .ne (ops .signbit (r ), ops .signbit (b )),
694
+ )
695
+ return ops .where (cond , ops .add (r , b ), r )
696
+
697
+ @staticmethod
698
+ def round_to_int (a , dtype ):
699
+ return ops .to_dtype (ops .round (a ), dtype )
700
+
701
+
702
+ class OpOverrides (OpDecompositions ):
703
+ def __init__ (self , parent ):
704
+ super ().__init__ ()
705
+ self ._parent = parent
706
+
707
+ @staticmethod
708
+ def paren (string : str ) -> str :
709
+ def all_in_parens (string : str ) -> bool :
710
+ if string [0 ] != "(" or len (string ) < 2 :
711
+ return False
712
+ count = 1
713
+ for i , char in enumerate (string [1 :]):
714
+ if char == "(" :
715
+ count += 1
716
+ elif char == ")" :
717
+ count -= 1
718
+ if count == 0 and i != len (string ) - 2 :
719
+ return False
720
+ assert count == 0
721
+ return True
722
+
723
+ if (
724
+ isinstance (string , CSEVariable )
725
+ or re .match (r"^[a-z0-9_.]+$" , string , re .IGNORECASE )
726
+ or re .match (r"^\([^)]*\)$" , string , re .IGNORECASE )
727
+ or string == ""
728
+ ):
729
+ return string
730
+ # don't put extra parens for strings that are already wrapped in parens
731
+ if all_in_parens (string ):
732
+ return string
733
+ return f"({ string } )"
734
+
735
+ def __getattr__ (self , item ):
736
+ return getattr (self ._parent , item )
737
+
738
+ @staticmethod
739
+ def constant (value , dtype ):
740
+ return repr (value )
741
+
702
742
@staticmethod
703
743
def libdevice_sigmoid (x ):
704
744
one = ops .constant (1 , torch .int32 )
705
745
return ops .truediv (one , ops .add (one , ops .libdevice_exp (ops .neg (x ))))
706
746
707
- @staticmethod
708
- def relu (x ):
709
- return ops .maximum (x , ops .constant (0 , torch .int32 ))
710
-
711
747
@staticmethod
712
748
def libdevice_abs (x ):
713
749
return ops .abs (x )
@@ -760,36 +796,6 @@ def bitwise_left_shift(x, y):
760
796
def bitwise_right_shift (x , y ):
761
797
return f"{ OpOverrides .paren (x )} >> { OpOverrides .paren (y )} "
762
798
763
- @staticmethod
764
- def remainder (a , b ):
765
- r = ops .mod (a , b )
766
- cond = ops .and_ (
767
- ops .ne (r , ops .constant (0 , torch .int32 )),
768
- ops .ne (ops .signbit (r ), ops .signbit (b )),
769
- )
770
- return ops .where (cond , ops .add (r , b ), r )
771
-
772
- @staticmethod
773
- def fma (x , y , z ):
774
- # for backends that don't override this (halide)
775
- return ops .add (ops .mul (x , y ), z )
776
-
777
- @staticmethod
778
- def trunc_to_int (a , dtype ):
779
- return ops .to_dtype (ops .trunc (a ), dtype )
780
-
781
- @staticmethod
782
- def floor_to_int (a , dtype ):
783
- return ops .to_dtype (ops .floor (a ), dtype )
784
-
785
- @staticmethod
786
- def ceil_to_int (a , dtype ):
787
- return ops .to_dtype (ops .ceil (a ), dtype )
788
-
789
- @staticmethod
790
- def round_to_int (a , dtype ):
791
- return ops .to_dtype (ops .round (a ), dtype )
792
-
793
799
@staticmethod
794
800
def int_truediv (a , b ):
795
801
# TODO: this is wrong
0 commit comments