diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b3f08447e..dfc0e7882 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6565,10 +6565,19 @@ def aten_pow(self: TReal, exponent: TTensor) -> TReal: return op.Pow(self, exponent) -def aten_prelu(self: TensorType, weight: TensorType) -> TensorType: +@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True) +def aten_prelu(self: TReal, weight: TReal) -> TReal: """prelu(Tensor self, Tensor weight) -> Tensor""" - raise NotImplementedError() + zero = op.CastLike(0, self) + rank = len(self.shape) + if rank == 0: + # e.g. self: [], weight: [1] + weight = op.Squeeze(weight) + elif rank >= 2: + # e.g. self: [5,10,5], weight: [10] + weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2)) + return op.Add(op.Max(self, zero), op.Mul(weight, op.Min(self, zero))) def aten_prelu_backward( diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index 2064c8b87..3a9717cc3 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -34,6 +34,7 @@ import onnxscript import onnxscript.evaluator +from onnxscript import ir from onnxscript.function_libs.torch_lib import graph_building from tests.function_libs.torch_lib import error_reproduction @@ -538,7 +539,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, onnx.checker.check_model(onnx_model, full_check=True) except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e: raise AssertionError( - f"ONNX model is invalid. Model:\n{onnx.printer.to_text(onnx_model)}" + f"ONNX model is invalid. Model:\n{ir.serde.deserialize_model(onnx_model)}" ) from e try: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b7038ada7..b4f3c5701 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1311,6 +1311,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("polar", core_ops.aten_polar), TorchLibOpInfo("pow", core_ops.aten_pow), + TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu), TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True), TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True), TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True),