Skip to content

Commit b383d29

Browse files
aakhundovfacebook-github-bot
authored andcommitted
Add elu AIT converter (facebookincubator#953)
Summary: Pull Request resolved: facebookincubator#953 ATT Reviewed By: frank-wei, zoranzhao Differential Revision: D50384768 fbshipit-source-id: c8a9787fb9313e3fcf712212fb27e2d9ad970a63
1 parent 631ab16 commit b383d29

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

fx2ait/fx2ait/converters/ait_converters.py

+25
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,31 @@ def acc_ops_relu(
488488
return elementwise(FuncEnum.RELU)(input_val)
489489

490490

491+
@ait_converter(acc_ops.elu)
492+
def acc_ops_elu(
493+
target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
494+
) -> ConverterOutput:
495+
input_val = kwargs["input"]
496+
if not isinstance(input_val, AITTensor):
497+
raise RuntimeError(f"Unexpected input for {name}: {input_val}")
498+
499+
inputs = [input_val]
500+
if "alpha" in kwargs:
501+
if not isinstance(kwargs["alpha"], (int, float)):
502+
raise RuntimeError(
503+
f"When specified, alpha in {name} must be a scalar: {input_val}"
504+
)
505+
input_alpha = AITTensor(
506+
shape=[],
507+
dtype=input_val._attrs["dtype"],
508+
name="alpha",
509+
value=kwargs["alpha"],
510+
)
511+
inputs.append(input_alpha)
512+
513+
return elementwise(FuncEnum.ELU)(*inputs)
514+
515+
491516
@ait_converter(acc_ops.leaky_relu)
492517
def acc_ops_leaky_relu(
493518
target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
16+
import torch
17+
from fx2ait.acc_tracer import acc_ops
18+
from fx2ait.tools.common_fx2ait import AITTestCase
19+
20+
21+
class TestEluConverter(AITTestCase):
22+
def test_elu(self):
23+
class TestModule(torch.nn.Module):
24+
def forward(self, x: torch.Tensor) -> torch.Tensor:
25+
return torch.nn.functional.elu(x)
26+
27+
model = TestModule().cuda().half()
28+
inputs = [
29+
torch.randn(2, 3).half().cuda(),
30+
]
31+
32+
self.run_test(model, inputs, expected_ops={acc_ops.elu})
33+
34+
def test_elu_with_alpha(self):
35+
class TestModule(torch.nn.Module):
36+
def forward(self, x: torch.Tensor) -> torch.Tensor:
37+
return torch.nn.functional.elu(x, alpha=3.14)
38+
39+
model = TestModule().cuda().half()
40+
inputs = [
41+
torch.randn(2, 3).half().cuda(),
42+
]
43+
44+
self.run_test(model, inputs, expected_ops={acc_ops.elu})

0 commit comments

Comments
 (0)