@@ -51,6 +51,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
5151}
5252
5353namespace internal {
54+ template <
55+ typename CTYPE_COMPUTE,
56+ typename CTYPE_OUT,
57+ typename Op,
58+ typename ... Args>
59+ inline void dtype_specialized_elementwise_fn_impl (
60+ const Op& compute_fun,
61+ KernelRuntimeContext& ctx,
62+ const Tensor& out,
63+ Args... inputs) {
64+ constexpr auto kNumInputs = sizeof ...(inputs);
65+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMPUTE)) && ...));
66+
67+ ::executorch::extension::parallel_for (
68+ 0 ,
69+ out.numel(),
70+ ::executorch::extension::internal::GRAIN_SIZE,
71+ [&](const auto begin, const auto end) {
72+ std::array<const CTYPE_COMPUTE*, kNumInputs > inputs_data_ptrs = {
73+ inputs.first ->template const_data_ptr <CTYPE_COMPUTE>()...};
74+
75+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
76+
77+ const auto range =
78+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
79+ auto begin_it = range.begin ();
80+ begin_it += begin;
81+ for (; (*begin_it)[0 ] < end; ++begin_it) {
82+ const auto & indexes = *begin_it;
83+ std::array<CTYPE_COMPUTE, kNumInputs > loaded_inputs;
84+ for (const auto idx : c10::irange (kNumInputs )) {
85+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
86+ }
87+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
88+ }
89+ });
90+ }
91+
5492template <typename CTYPE_COMPUTE, typename Op, typename ... Args>
5593inline bool validate_elementwise_fn_inputs (
5694 const Op& compute_fun,
@@ -81,18 +119,12 @@ template <
81119 const char * op_name,
82120 typename Op,
83121 typename ... Args>
84- inline void apply_elementwise_fn (
122+ inline void apply_elementwise_fn_generic_impl (
85123 const Op& compute_fun,
86124 KernelRuntimeContext& ctx,
87125 const Tensor& out,
88126 SupportedTensorDtypes out_dtypes,
89127 Args... inputs) {
90- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
91- compute_fun, ctx, out, out_dtypes, inputs...);
92- if (!inputs_valid) {
93- return ;
94- }
95-
96128 constexpr auto kNumInputs = sizeof ...(inputs);
97129
98130 struct InputInfo {
@@ -138,6 +170,63 @@ inline void apply_elementwise_fn(
138170 });
139171}
140172
173+ template <
174+ typename CTYPE_COMPUTE,
175+ const char * op_name,
176+ typename Op,
177+ typename ... Args>
178+ inline void apply_elementwise_fn_runtime_out_dtypes (
179+ const Op& compute_fun,
180+ KernelRuntimeContext& ctx,
181+ const Tensor& out,
182+ SupportedTensorDtypes out_dtypes,
183+ Args... inputs) {
184+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
185+ compute_fun, ctx, out, out_dtypes, inputs...);
186+ if (!inputs_valid) {
187+ return ;
188+ }
189+
190+ apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
191+ compute_fun, ctx, out, out_dtypes, inputs...);
192+ }
193+
194+ template <
195+ typename CTYPE_COMPUTE,
196+ const char * op_name,
197+ SupportedTensorDtypes out_dtypes,
198+ typename Op,
199+ typename ... Args>
200+ inline void apply_elementwise_fn (
201+ const Op& compute_fun,
202+ KernelRuntimeContext& ctx,
203+ const Tensor& out,
204+ Args... inputs) {
205+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
206+ compute_fun, ctx, out, out_dtypes, inputs...);
207+ if (!inputs_valid) {
208+ return ;
209+ }
210+
211+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
212+ const bool all_inputs_compute_dtype =
213+ ((inputs.first ->scalar_type () == compute_type) && ...);
214+
215+ constexpr ScalarType out_specialized_scalar_type =
216+ specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
217+ if (all_inputs_compute_dtype &&
218+ out.scalar_type () == out_specialized_scalar_type) {
219+ using CTYPE_OUT =
220+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
221+ dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
222+ compute_fun, ctx, out, inputs...);
223+ return ;
224+ }
225+
226+ apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
227+ compute_fun, ctx, out, out_dtypes, inputs...);
228+ }
229+
141230// / DEPRECATED: prefer the variant with out_dtypes in the template argument.
142231template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
143232inline void apply_unitensor_elementwise_fn (
@@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn(
147236 SupportedTensorDtypes a_dtypes,
148237 const Tensor& out,
149238 SupportedTensorDtypes out_dtypes) {
150- internal::apply_elementwise_fn <CTYPE_COMPUTE, op_name>(
239+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMPUTE, op_name>(
151240 compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
152241}
153242
@@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn(
162251 const Tensor& a,
163252 SupportedTensorDtypes a_dtypes,
164253 const Tensor& out) {
165- internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
166- compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
254+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes >(
255+ compute_fun, ctx, out, std::make_pair (&a, a_dtypes));
167256}
168257
169258/* *
@@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
179268 SupportedTensorDtypes b_dtypes,
180269 const Tensor& out,
181270 SupportedTensorDtypes out_dtypes) {
182- internal::apply_elementwise_fn <CTYPE_COMPUTE, op_name>(
271+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMPUTE, op_name>(
183272 compute_fun,
184273 ctx,
185274 out,
@@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
206295 const Tensor& b,
207296 SupportedTensorDtypes b_dtypes,
208297 const Tensor& out) {
209- internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
298+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes >(
210299 compute_fun,
211300 ctx,
212301 out,
213- out_dtypes,
214302 std::make_pair (&a, a_dtypes),
215303 std::make_pair (&b, b_dtypes));
216304}
@@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
230318 SupportedTensorDtypes c_dtypes,
231319 const Tensor& out,
232320 SupportedTensorDtypes out_dtypes) {
233- internal::apply_elementwise_fn <CTYPE_COMPUTE, op_name>(
321+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMPUTE, op_name>(
234322 compute_fun,
235323 ctx,
236324 out,
@@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
275363 const Tensor& c,
276364 SupportedTensorDtypes c_dtypes,
277365 const Tensor& out) {
278- internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
366+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes >(
279367 compute_fun,
280368 ctx,
281369 out,
282- out_dtypes,
283370 std::make_pair (&a, a_dtypes),
284371 std::make_pair (&b, b_dtypes),
285372 std::make_pair (&c, c_dtypes));
0 commit comments