From 88e9737fe5b3f8e25dcf2a54cae2c1146dfa80ae Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Tue, 11 Jun 2024 15:22:57 -0700
Subject: [PATCH] Change the elementwise broadcasting contract from graph to
 kernel (#3894)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3894

Currently, a graph-level pass handles limited broadcasting of elementwise ops
when the input tensors are not the same size. This diff moves that
responsibility down to the kernels, which is how ET and the portable ops
handle it. For now, the affected ops are only `add`, `sub`, `mul`, and `div`,
but more will follow.

We retain the implementations for the reference kernels because we want to
avoid linking the portable ops directly, which takes forever at compile time.
We can also use a much smaller set of types (basically only `float`).

This change lets us remove a hack in the RNNT Joiner and run it natively. It
takes a large performance hit, which will be fixed by getting
broadcast-friendly kernels from Cadence.

Finally, we remove the binop tests in `test_aten_ops.py`, which were using
strange types and had been on the chopping block for a while. We also remove
the rSubScalar test, since we no longer trace to rsub.
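For context, the kernel-side contract described above means each elementwise
kernel resolves mismatched input shapes itself rather than relying on a graph
pass to insert expansions first. The sketch below is a hypothetical Python
illustration of NumPy-style broadcasting inside an `add` kernel; the names
`broadcast_shape` and `broadcast_add` are made up for this example and are not
the actual ExecuTorch or Cadence implementation.

```python
from itertools import product, zip_longest


def broadcast_shape(a_shape, b_shape):
    """Right-align the two shapes; each dim pair must match or contain a 1."""
    out = []
    for da, db in zip_longest(reversed(a_shape), reversed(b_shape), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a_shape} and {b_shape} do not broadcast")
        out.append(max(da, db))
    return tuple(reversed(out))


def broadcast_add(a, a_shape, b, b_shape):
    """Elementwise add of two flat float buffers with in-kernel broadcasting."""
    out_shape = broadcast_shape(a_shape, b_shape)

    def strides(shape):
        s, acc = [], 1
        for d in reversed(shape):
            s.append(acc)
            acc *= d
        return list(reversed(s))

    def flat_index(idx, shape, strd):
        # Right-align idx against this input's shape; size-1 dims
        # contribute offset 0, which is the broadcast step itself.
        off = len(idx) - len(shape)
        return sum((i % d) * st for i, d, st in zip(idx[off:], shape, strd))

    sa, sb = strides(a_shape), strides(b_shape)
    out = []
    for idx in product(*(range(d) for d in out_shape)):
        out.append(a[flat_index(idx, a_shape, sa)] + b[flat_index(idx, b_shape, sb)])
    return out, out_shape
```

For example, adding a `(3,)` buffer to a `(2, 3)` buffer repeats the smaller
operand across the leading dimension, with no graph-level expand node needed.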
Reviewed By: dulinriley

Differential Revision: D58207691

fbshipit-source-id: 12ec6df3b37eaea7ec1a7327c39a0a659bfaf1c0
---
 kernels/portable/cpu/util/targets.bzl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 688cd0dd92..bd55b4da30 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -73,7 +73,7 @@ def define_common_targets():
             "//executorch/runtime/kernel:kernel_includes",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
-        visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
+        visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],
     )

     runtime.cxx_library(