@@ -60,8 +60,43 @@ using op_call_result =
6060
6161template  <
6262    typename  CTYPE_COMMON,
63+     typename  CTYPE_OUT,
6364    typename  Op,
64-   typename ... Args>
65+     typename ... Args>
66+ inline  void  dtype_specialized_elementwise_fn_impl (
67+     const  Op& compute_fun,
68+     KernelRuntimeContext& ctx,
69+     const  Tensor& out,
70+     Args... inputs) {
71+   constexpr  auto  kNumInputs  = sizeof ...(inputs);
72+   ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
73+ 
74+   std::array<const  CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
75+       inputs.first ->template  const_data_ptr <CTYPE_COMMON>()...};
76+ 
77+   CTYPE_OUT* const  data_out = out.mutable_data_ptr <CTYPE_OUT>();
78+ 
79+   ::executorch::extension::parallel_for (
80+       0 ,
81+       out.numel(),
82+       ::executorch::extension::internal::GRAIN_SIZE,
83+       [&](const  auto  begin, const  auto  end) {
84+         const  auto  range =
85+             BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
86+         auto  begin_it = range.begin ();
87+         begin_it += begin;
88+         for  (; (*begin_it)[0 ] < end; ++begin_it) {
89+           const  auto & indexes = *begin_it;
90+           std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
91+           for  (const  auto  idx : c10::irange (kNumInputs )) {
92+             loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
93+           }
94+           data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
95+         }
96+       });
97+ }
98+ 
99+ template  <typename  CTYPE_COMMON, typename  Op, typename ... Args>
65100inline  bool  validate_elementwise_fn_inputs (
66101    const  Op& compute_fun,
67102    KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80115      ctx,
81116      (check_input_dtype (inputs, compute_type) && ...) &&
82117          internal::check_tensor_dtype (out, out_dtypes, compute_type),
83-       InvalidArgument, false );
118+       InvalidArgument,
119+       false );
84120
85121  return  true ;
86122}
@@ -90,22 +126,12 @@ template <
90126    const  char * op_name,
91127    typename  Op,
92128    typename ... Args>
93- inline  void  apply_elementwise_fn (
129+ inline  void  apply_elementwise_fn_generic_impl (
94130    const  Op& compute_fun,
95131    KernelRuntimeContext& ctx,
96132    const  Tensor& out,
97133    SupportedTensorDtypes out_dtypes,
98134    Args... inputs) {
99-   const  bool  inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100-       compute_fun,
101-       ctx,
102-       out,
103-       out_dtypes,
104-       inputs...);
105-   if  (!inputs_valid) {
106-     return ;
107-   }
108- 
109135  constexpr  auto  kNumInputs  = sizeof ...(inputs);
110136
111137  struct  InputInfo  {
@@ -157,6 +183,63 @@ inline void apply_elementwise_fn(
157183        }
158184      });
159185}
186+ 
187+ template  <
188+     typename  CTYPE_COMMON,
189+     const  char * op_name,
190+     typename  Op,
191+     typename ... Args>
192+ inline  void  apply_elementwise_fn_runtime_out_dtypes (
193+     const  Op& compute_fun,
194+     KernelRuntimeContext& ctx,
195+     const  Tensor& out,
196+     SupportedTensorDtypes out_dtypes,
197+     Args... inputs) {
198+   const  bool  inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199+       compute_fun, ctx, out, out_dtypes, inputs...);
200+   if  (!inputs_valid) {
201+     return ;
202+   }
203+ 
204+   apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205+       compute_fun, ctx, out, out_dtypes, inputs...);
206+ }
207+ 
208+ template  <
209+     typename  CTYPE_COMMON,
210+     const  char * op_name,
211+     SupportedTensorDtypes out_dtypes,
212+     typename  Op,
213+     typename ... Args>
214+ inline  void  apply_elementwise_fn (
215+     const  Op& compute_fun,
216+     KernelRuntimeContext& ctx,
217+     const  Tensor& out,
218+     Args... inputs) {
219+   const  bool  inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220+       compute_fun, ctx, out, out_dtypes, inputs...);
221+   if  (!inputs_valid) {
222+     return ;
223+   }
224+ 
225+   constexpr  auto  compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
226+   const  bool  all_inputs_compute_dtype =
227+       ((inputs.first ->scalar_type () == compute_type) && ...);
228+ 
229+   constexpr  ScalarType out_specialized_scalar_type =
230+       specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
231+   if  (all_inputs_compute_dtype &&
232+       out.scalar_type () == out_specialized_scalar_type) {
233+     using  CTYPE_OUT =
234+         typename  ScalarTypeToCppType<out_specialized_scalar_type>::type;
235+     dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
236+         compute_fun, ctx, out, inputs...);
237+     return ;
238+   }
239+ 
240+   apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
241+       compute_fun, ctx, out, out_dtypes, inputs...);
242+ }
160243} //  namespace internal
161244
162245/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,19 +251,23 @@ inline void apply_unitensor_elementwise_fn(
168251    SupportedTensorDtypes a_dtypes,
169252    const  Tensor& out,
170253    SupportedTensorDtypes out_dtypes) {
171-   internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
254+   internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
172255      compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
173256}
174257
175- template  <typename  CTYPE_COMMON, const  char * op_name, SupportedTensorDtypes out_dtypes, typename  Op>
258+ template  <
259+     typename  CTYPE_COMMON,
260+     const  char * op_name,
261+     SupportedTensorDtypes out_dtypes,
262+     typename  Op>
176263inline  void  apply_unitensor_elementwise_fn (
177264    const  Op& compute_fun,
178265    KernelRuntimeContext& ctx,
179266    const  Tensor& a,
180267    SupportedTensorDtypes a_dtypes,
181268    const  Tensor& out) {
182-   internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183-       compute_fun, ctx, out, out_dtypes,  std::make_pair (&a, a_dtypes));
269+   internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
270+       compute_fun, ctx, out, std::make_pair (&a, a_dtypes));
184271}
185272
186273/* *
@@ -196,7 +283,7 @@ inline void apply_bitensor_elementwise_fn(
196283    SupportedTensorDtypes b_dtypes,
197284    const  Tensor& out,
198285    SupportedTensorDtypes out_dtypes) {
199-   internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
286+   internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
200287      compute_fun,
201288      ctx,
202289      out,
@@ -210,7 +297,11 @@ inline void apply_bitensor_elementwise_fn(
210297 * perform a computation and write to the corresponding element of the output. 
211298 * Tensor broadcasting is applied wherever it is required. 
212299 */  
213- template  <typename  CTYPE_COMMON, const  char * op_name, SupportedTensorDtypes out_dtypes, typename  Op>
300+ template  <
301+     typename  CTYPE_COMMON,
302+     const  char * op_name,
303+     SupportedTensorDtypes out_dtypes,
304+     typename  Op>
214305inline  void  apply_bitensor_elementwise_fn (
215306    const  Op& compute_fun,
216307    KernelRuntimeContext& ctx,
@@ -219,11 +310,10 @@ inline void apply_bitensor_elementwise_fn(
219310    const  Tensor& b,
220311    SupportedTensorDtypes b_dtypes,
221312    const  Tensor& out) {
222-   internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
313+   internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
223314      compute_fun,
224315      ctx,
225316      out,
226-       out_dtypes,
227317      std::make_pair (&a, a_dtypes),
228318      std::make_pair (&b, b_dtypes));
229319}
@@ -243,7 +333,7 @@ inline void apply_tritensor_elementwise_fn(
243333    SupportedTensorDtypes c_dtypes,
244334    const  Tensor& out,
245335    SupportedTensorDtypes out_dtypes) {
246-   internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
336+   internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
247337      compute_fun,
248338      ctx,
249339      out,
@@ -273,7 +363,11 @@ inline void apply_tritensor_elementwise_fn(
273363 * static constexpr const char op_name[] = "my_op"; 
274364 * apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>.
275365 */  
276- template  <typename  CTYPE_COMMON, const  char * op_name, SupportedTensorDtypes out_dtypes, typename  Op>
366+ template  <
367+     typename  CTYPE_COMMON,
368+     const  char * op_name,
369+     SupportedTensorDtypes out_dtypes,
370+     typename  Op>
277371inline  void  apply_tritensor_elementwise_fn (
278372    const  Op& compute_fun,
279373    KernelRuntimeContext& ctx,
@@ -284,11 +378,10 @@ inline void apply_tritensor_elementwise_fn(
284378    const  Tensor& c,
285379    SupportedTensorDtypes c_dtypes,
286380    const  Tensor& out) {
287-   internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
381+   internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
288382      compute_fun,
289383      ctx,
290384      out,
291-       out_dtypes,
292385      std::make_pair (&a, a_dtypes),
293386      std::make_pair (&b, b_dtypes),
294387      std::make_pair (&c, c_dtypes));
0 commit comments