@@ -77,6 +77,47 @@ at::Tensor custom_sdpa_aten(
7777 // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
7878 const std::optional<double > scale);
7979
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
// Out-variant of the quantized SDPA kernel that supplies its own runtime
// context; `output` is written in place and returned.
Tensor& custom_quantized_sdpa_out_no_context(
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    const optional<Tensor> q_zero_points,
    const optional<Tensor> q_scales,
    const optional<Tensor> k_zero_points,
    const optional<Tensor> k_scales,
    const optional<Tensor> v_zero_points,
    const optional<Tensor> v_scales,
    Tensor& output);

// ATen-facing functional variant: allocates the output tensor itself and
// returns it (used for eager-mode registration in the llama library).
at::Tensor custom_quantized_sdpa_aten(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale,
    const std::optional<at::Tensor>& q_zero_points,
    const std::optional<at::Tensor>& q_scales,
    const std::optional<at::Tensor>& k_zero_points,
    const std::optional<at::Tensor>& k_scales,
    const std::optional<at::Tensor>& v_zero_points,
    const std::optional<at::Tensor>& v_scales);
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
120+
80121Tensor& update_cache_out_no_context (
81122 const Tensor& value,
82123 Tensor& cache,
@@ -198,6 +239,85 @@ at::Tensor custom_sdpa_aten(
198239 return output;
199240}
200241
242+ #ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
243+ Tensor& custom_quantized_sdpa_out_no_context (
244+ const Tensor& q,
245+ const Tensor& k,
246+ const Tensor& v,
247+ const int64_t start_pos,
248+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
249+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
250+ const optional<Tensor> attn_mask,
251+ const double dropout_p,
252+ const bool is_causal,
253+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
254+ const optional<double > scale,
255+ const optional<Tensor> q_zero_points,
256+ const optional<Tensor> q_scales,
257+ const optional<Tensor> k_zero_points,
258+ const optional<Tensor> k_scales,
259+ const optional<Tensor> v_zero_points,
260+ const optional<Tensor> v_scales,
261+ Tensor& output) {
262+ executorch::aten::RuntimeContext context{};
263+ return torch::executor::native::custom_quantized_sdpa_out (
264+ context,
265+ q,
266+ k,
267+ v,
268+ start_pos,
269+ attn_mask,
270+ dropout_p,
271+ is_causal,
272+ scale,
273+ q_zero_points,
274+ q_scales,
275+ k_zero_points,
276+ k_scales,
277+ v_zero_points,
278+ v_scales,
279+ output);
280+ }
281+
282+ at::Tensor custom_quantized_sdpa_aten (
283+ const at::Tensor& q,
284+ const at::Tensor& k,
285+ const at::Tensor& v,
286+ const int64_t start_pos,
287+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
288+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
289+ const std::optional<at::Tensor> attn_mask,
290+ const double dropout_p,
291+ const bool is_causal,
292+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
293+ const std::optional<double > scale,
294+ const std::optional<at::Tensor>& q_zero_points,
295+ const std::optional<at::Tensor>& q_scales,
296+ const std::optional<at::Tensor>& k_zero_points,
297+ const std::optional<at::Tensor>& k_scales,
298+ const std::optional<at::Tensor>& v_zero_points,
299+ const std::optional<at::Tensor>& v_scales) {
300+ auto output = at::empty (q.sizes ());
301+ WRAP_TO_ATEN (custom_quantized_sdpa_out_no_context, 14 )
302+ (q,
303+ k,
304+ v,
305+ start_pos,
306+ attn_mask,
307+ dropout_p,
308+ is_causal,
309+ scale,
310+ q_zero_points,
311+ q_scales,
312+ k_zero_points,
313+ k_scales,
314+ v_zero_points,
315+ v_scales,
316+ output);
317+ return output;
318+ }
319+ #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
320+
201321Tensor& update_cache_out_no_context (
202322 const Tensor& value,
203323 Tensor& cache,
@@ -245,6 +365,20 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
245365 m.def (
246366 " update_cache.out(Tensor value, Tensor(a!) cache, "
247367 " SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)" );
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
  // Functional and out-variant schemas for the quantized SDPA custom op.
  // NOTE(review): "drpout_p" looks misspelled ("dropout_p"); renaming it
  // would change the keyword-argument name Python callers see, so confirm
  // whether it must stay consistent with the existing custom_sdpa schema
  // before fixing.
  m.def(
      "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
      "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
      "Tensor? v_scales=None) -> Tensor");
  m.def(
      "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
      "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
248382}
249383
250384// TODO: Rename this file to op_custom_ops_aot.cpp
@@ -263,4 +397,13 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
263397 m.impl (
264398 " update_cache.out" ,
265399 WRAP_TO_ATEN (torch::executor::native::update_cache_out_no_context, 3 ));
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
  // Bind the functional variant directly and wrap the out-variant; the
  // trailing 14 is the index of the `out` argument in
  // custom_quantized_sdpa_out_no_context's parameter list.
  m.impl(
      "custom_quantized_sdpa",
      torch::executor::native::custom_quantized_sdpa_aten);
  m.impl(
      "custom_quantized_sdpa.out",
      WRAP_TO_ATEN(
          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
266409}
0 commit comments