diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 60bfc678c409d..40d89eb0552fb 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -2475,6 +2475,15 @@ def fn(x):
             (torch.randn([16, 16]),),
         )
 
+    def test_tan(self):
+        def fn(x):
+            return aten.tan(x) + 2, aten.tan(x + 1)
+
+        self.common(
+            fn,
+            (torch.randn([16, 16]),),
+        )
+
     def test_tanh(self):
         def fn(x):
             return aten.tanh(x) + 2, aten.tanh(x + 1)
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 190cb18f83881..b291180bb7773 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -310,6 +310,10 @@ def logical_and(a, b):
     def logical_or(a, b):
         return f"{a} || {b}"
 
+    @staticmethod
+    def tan(a):
+        return f"{a}.tan()"
+
     @staticmethod
     def tanh(a):
         vec_one = f"decltype({a})(1)"
@@ -454,6 +458,10 @@ def rsqrt(x):
     def log1p(x):
         return f"std::log1p({x})"
 
+    @staticmethod
+    def tan(x):
+        return f"std::tan({x})"
+
     @staticmethod
     def tanh(x):
         return f"std::tanh({x})"
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 5de2775426e94..762db9f88cebf 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -251,6 +251,10 @@ def rsqrt(x):
     def log1p(x):
         return f"tl.libdevice.log1p({x})"
 
+    @staticmethod
+    def tan(x):
+        return f"tl.libdevice.tan({x})"
+
     @staticmethod
     def tanh(x):
         return f"tl.libdevice.tanh({x})"
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index b60c377e5e872..b799280d296e5 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3595,6 +3595,11 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None):
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 )
 
+register_pointwise(
+    aten.tan,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)
+
 register_pointwise(
     aten.tanh,
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
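
For reference, a minimal sketch of how the new lowering could be exercised end to end, assuming a PyTorch build where torch.compile routes through Inductor; the function, shapes, and tolerances below are illustrative and not part of this change:

    import torch

    def f(x):
        # torch.tan dispatches to aten.tan, which this change registers as a
        # pointwise lowering: tl.libdevice.tan in the Triton backend and
        # std::tan / the vectorized tan() in the C++ backend.
        return torch.tan(x) + 2

    compiled = torch.compile(f, backend="inductor")
    x = torch.randn(16, 16)
    torch.testing.assert_close(compiled(x), f(x), rtol=1e-4, atol=1e-4)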