Skip to content

Commit

Permalink
[Inductor] Lower aten.tan (pytorch#92837)
Browse files Browse the repository at this point in the history
  • Loading branch information
min-jean-cho authored and pytorchmergebot committed Jan 24, 2023
1 parent 19c9b09 commit 68a40a4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2475,6 +2475,15 @@ def fn(x):
(torch.randn([16, 16]),),
)

def test_tan(self):
def fn(x):
return aten.tan(x) + 2, aten.tan(x + 1)

self.common(
fn,
(torch.randn([16, 16]),),
)

def test_tanh(self):
def fn(x):
return aten.tanh(x) + 2, aten.tanh(x + 1)
Expand Down
8 changes: 8 additions & 0 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ def logical_and(a, b):
def logical_or(a, b):
return f"{a} || {b}"

@staticmethod
def tan(a):
return f"{a}.tan()"

@staticmethod
def tanh(a):
vec_one = f"decltype({a})(1)"
Expand Down Expand Up @@ -454,6 +458,10 @@ def rsqrt(x):
def log1p(x):
return f"std::log1p({x})"

@staticmethod
def tan(x):
return f"std::tan({x})"

@staticmethod
def tanh(x):
return f"std::tanh({x})"
Expand Down
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def rsqrt(x):
def log1p(x):
return f"tl.libdevice.log1p({x})"

@staticmethod
def tan(x):
return f"tl.libdevice.tan({x})"

@staticmethod
def tanh(x):
return f"tl.libdevice.tanh({x})"
Expand Down
5 changes: 5 additions & 0 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3595,6 +3595,11 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None):
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)

register_pointwise(
aten.tan,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)

register_pointwise(
aten.tanh,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
Expand Down

0 comments on commit 68a40a4

Please sign in to comment.