Skip to content

Commit fa746e3

Browse files
eellisonpytorchmergebot
authored andcommitted
[Easy] factor out inductor ophandler decompositions (pytorch#142400)
Factor out inductor operator decompositions Pull Request resolved: pytorch#142400 Approved by: https://github.com/Chillee, https://github.com/jansel ghstack dependencies: pytorch#134532, pytorch#142350
1 parent 1fb3d5a commit fa746e3

File tree

2 files changed

+80
-75
lines changed

2 files changed

+80
-75
lines changed

torch/_inductor/codegen/common.py

Lines changed: 79 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -613,51 +613,16 @@ def doprint(self, expr, *, simplify: bool = True, p=True):
613613
return super().doprint(expr)
614614

615615

616-
class OpOverrides:
617-
def __init__(self, parent):
618-
super().__init__()
619-
self._parent = parent
620-
621-
@staticmethod
622-
def paren(string: str) -> str:
623-
def all_in_parens(string: str) -> bool:
624-
if string[0] != "(" or len(string) < 2:
625-
return False
626-
count = 1
627-
for i, char in enumerate(string[1:]):
628-
if char == "(":
629-
count += 1
630-
elif char == ")":
631-
count -= 1
632-
if count == 0 and i != len(string) - 2:
633-
return False
634-
assert count == 0
635-
return True
636-
637-
if (
638-
isinstance(string, CSEVariable)
639-
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
640-
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
641-
or string == ""
642-
):
643-
return string
644-
# don't put extra parens for strings that are already wrapped in parens
645-
if all_in_parens(string):
646-
return string
647-
return f"({string})"
648-
649-
def __getattr__(self, item):
650-
return getattr(self._parent, item)
616+
class OpDecompositions:
617+
"""
618+
Decomposes inductor ops
619+
"""
651620

652621
@staticmethod
653622
def identity(value):
654623
# used to trigger cse
655624
return value
656625

657-
@staticmethod
658-
def constant(value, dtype):
659-
return repr(value)
660-
661626
@staticmethod
662627
def reciprocal(x):
663628
return ops.truediv(ops.constant(1, torch.int32), x)
@@ -699,15 +664,86 @@ def sigmoid(x):
699664
one = ops.constant(1, torch.int32)
700665
return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
701666

667+
@staticmethod
668+
def relu(x):
669+
return ops.maximum(x, ops.constant(0, torch.int32))
670+
671+
@staticmethod
672+
def fma(x, y, z):
673+
# for backends that don't override this (halide)
674+
return ops.add(ops.mul(x, y), z)
675+
676+
@staticmethod
677+
def floor_to_int(a, dtype):
678+
return ops.to_dtype(ops.floor(a), dtype)
679+
680+
@staticmethod
681+
def ceil_to_int(a, dtype):
682+
return ops.to_dtype(ops.ceil(a), dtype)
683+
684+
@staticmethod
685+
def trunc_to_int(a, dtype):
686+
return ops.to_dtype(ops.trunc(a), dtype)
687+
688+
@staticmethod
689+
def remainder(a, b):
690+
r = ops.mod(a, b)
691+
cond = ops.and_(
692+
ops.ne(r, ops.constant(0, torch.int32)),
693+
ops.ne(ops.signbit(r), ops.signbit(b)),
694+
)
695+
return ops.where(cond, ops.add(r, b), r)
696+
697+
@staticmethod
698+
def round_to_int(a, dtype):
699+
return ops.to_dtype(ops.round(a), dtype)
700+
701+
702+
class OpOverrides(OpDecompositions):
703+
def __init__(self, parent):
704+
super().__init__()
705+
self._parent = parent
706+
707+
@staticmethod
708+
def paren(string: str) -> str:
709+
def all_in_parens(string: str) -> bool:
710+
if string[0] != "(" or len(string) < 2:
711+
return False
712+
count = 1
713+
for i, char in enumerate(string[1:]):
714+
if char == "(":
715+
count += 1
716+
elif char == ")":
717+
count -= 1
718+
if count == 0 and i != len(string) - 2:
719+
return False
720+
assert count == 0
721+
return True
722+
723+
if (
724+
isinstance(string, CSEVariable)
725+
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
726+
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
727+
or string == ""
728+
):
729+
return string
730+
# don't put extra parens for strings that are already wrapped in parens
731+
if all_in_parens(string):
732+
return string
733+
return f"({string})"
734+
735+
def __getattr__(self, item):
736+
return getattr(self._parent, item)
737+
738+
@staticmethod
739+
def constant(value, dtype):
740+
return repr(value)
741+
702742
@staticmethod
703743
def libdevice_sigmoid(x):
704744
one = ops.constant(1, torch.int32)
705745
return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))
706746

707-
@staticmethod
708-
def relu(x):
709-
return ops.maximum(x, ops.constant(0, torch.int32))
710-
711747
@staticmethod
712748
def libdevice_abs(x):
713749
return ops.abs(x)
@@ -760,36 +796,6 @@ def bitwise_left_shift(x, y):
760796
def bitwise_right_shift(x, y):
761797
return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
762798

763-
@staticmethod
764-
def remainder(a, b):
765-
r = ops.mod(a, b)
766-
cond = ops.and_(
767-
ops.ne(r, ops.constant(0, torch.int32)),
768-
ops.ne(ops.signbit(r), ops.signbit(b)),
769-
)
770-
return ops.where(cond, ops.add(r, b), r)
771-
772-
@staticmethod
773-
def fma(x, y, z):
774-
# for backends that don't override this (halide)
775-
return ops.add(ops.mul(x, y), z)
776-
777-
@staticmethod
778-
def trunc_to_int(a, dtype):
779-
return ops.to_dtype(ops.trunc(a), dtype)
780-
781-
@staticmethod
782-
def floor_to_int(a, dtype):
783-
return ops.to_dtype(ops.floor(a), dtype)
784-
785-
@staticmethod
786-
def ceil_to_int(a, dtype):
787-
return ops.to_dtype(ops.ceil(a), dtype)
788-
789-
@staticmethod
790-
def round_to_int(a, dtype):
791-
return ops.to_dtype(ops.round(a), dtype)
792-
793799
@staticmethod
794800
def int_truediv(a, b):
795801
# TODO: this is wrong

torch/_inductor/ops_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def _arg_str(a) -> str:
5151
# implementations make heavy use of __getattr__ magic, and pre-existing
5252
# stubs for methods would interfere with this mechanism.
5353
#
54-
# TODO: A superclass that does desugaring for operations like
55-
# reciprocal/square might be useful.
54+
# See OpDecompositions for superclass that desugars operations like reciprocal/square
5655
class OpsHandler(Protocol[T]):
5756
"""
5857
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,

0 commit comments

Comments
 (0)