Skip to content

Commit

Permalink
[torchlib] Implement aten::prelu (#1728)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Jul 11, 2024
1 parent 54537e4 commit c06e7ab
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
13 changes: 11 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6565,10 +6565,19 @@ def aten_pow(self: TReal, exponent: TTensor) -> TReal:
return op.Pow(self, exponent)


def aten_prelu(self: TensorType, weight: TensorType) -> TensorType:
@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True)
def aten_prelu(self: TReal, weight: TReal) -> TReal:
"""prelu(Tensor self, Tensor weight) -> Tensor"""

raise NotImplementedError()
zero = op.CastLike(0, self)
rank = len(self.shape)
if rank == 0:
# e.g. self: [], weight: [1]
weight = op.Squeeze(weight)
elif rank >= 2:
# e.g. self: [5,10,5], weight: [10]
weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2))
return op.Add(op.Max(self, zero), op.Mul(weight, op.Min(self, zero)))


def aten_prelu_backward(
Expand Down
3 changes: 2 additions & 1 deletion tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import onnxscript
import onnxscript.evaluator
from onnxscript import ir
from onnxscript.function_libs.torch_lib import graph_building
from tests.function_libs.torch_lib import error_reproduction

Expand Down Expand Up @@ -538,7 +539,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
raise AssertionError(
f"ONNX model is invalid. Model:\n{onnx.printer.to_text(onnx_model)}"
f"ONNX model is invalid. Model:\n{ir.serde.deserialize_model(onnx_model)}"
) from e

try:
Expand Down
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,7 @@ def _where_input_wrangler(
),
TorchLibOpInfo("polar", core_ops.aten_polar),
TorchLibOpInfo("pow", core_ops.aten_pow),
TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True),
Expand Down

0 comments on commit c06e7ab

Please sign in to comment.