Skip to content

Commit dc17310

Browse files
committed
optimize kwarg extraction
1 parent bead762 commit dc17310

File tree

2 files changed

+117
-74
lines changed

2 files changed

+117
-74
lines changed

pyo3-macros-backend/src/params.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ pub fn impl_arg_params(
108108
}
109109
});
110110

111+
let parameter_names = positional_parameter_names.iter().chain(
112+
spec.signature
113+
.python_signature
114+
.keyword_only_parameters
115+
.iter()
116+
.map(|(name, _)| name),
117+
);
118+
111119
let num_params = positional_parameter_names.len() + keyword_only_parameters.len();
112120

113121
let mut option_pos = 0usize;
@@ -161,14 +169,20 @@ pub fn impl_arg_params(
161169
// create array of arguments, and then parse
162170
(
163171
quote! {
164-
const DESCRIPTION: #pyo3_path::impl_::extract_argument::FunctionDescription = #pyo3_path::impl_::extract_argument::FunctionDescription {
165-
cls_name: #cls_name,
166-
func_name: stringify!(#python_name),
167-
positional_parameter_names: &[#(#positional_parameter_names),*],
168-
positional_only_parameters: #positional_only_parameters,
169-
required_positional_parameters: #required_positional_parameters,
170-
keyword_only_parameters: &[#(#keyword_only_parameters),*],
171-
};
172+
const PARAMETER_NAMES: &[&str] = &[#(#parameter_names),*];
173+
fn argument_lookup_by_name(name: &str) -> Option<usize> {
174+
PARAMETER_NAMES.iter().position(|&n| n == name)
175+
}
176+
177+
const DESCRIPTION: #pyo3_path::impl_::extract_argument::FunctionDescription = #pyo3_path::impl_::extract_argument::FunctionDescription::new(
178+
#cls_name,
179+
stringify!(#python_name),
180+
&[#(#positional_parameter_names),*],
181+
#positional_only_parameters,
182+
#required_positional_parameters,
183+
&[#(#keyword_only_parameters),*],
184+
argument_lookup_by_name,
185+
);
172186
let mut #args_array = [::std::option::Option::None; #num_params];
173187
let (_args, _kwargs) = #extract_expression;
174188
#from_py_with

src/impl_/extract_argument.rs

Lines changed: 95 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::borrow::Cow;
2+
13
use crate::{
24
exceptions::PyTypeError,
35
ffi,
@@ -305,15 +307,52 @@ pub struct KeywordOnlyParameterDescription {
305307

306308
/// Function argument specification for a `#[pyfunction]` or `#[pymethod]`.
307309
pub struct FunctionDescription {
308-
pub cls_name: Option<&'static str>,
309-
pub func_name: &'static str,
310-
pub positional_parameter_names: &'static [&'static str],
311-
pub positional_only_parameters: usize,
312-
pub required_positional_parameters: usize,
313-
pub keyword_only_parameters: &'static [KeywordOnlyParameterDescription],
310+
cls_name: Option<&'static str>,
311+
func_name: &'static str,
312+
positional_parameter_names: &'static [&'static str],
313+
positional_only_parameters: usize,
314+
required_positional_parameters: usize,
315+
keyword_only_parameters: &'static [KeywordOnlyParameterDescription],
316+
/// Optimized lookup of keyword argument names to parameter indices.
317+
argument_lookup_by_name: fn(&str) -> Option<usize>,
318+
/// Whether any keywords are required, useful to avoid iterating if not needed.
319+
has_required_keywords: bool,
314320
}
315321

316322
impl FunctionDescription {
323+
/// Builds a `FunctionDescription` from its individual parts.
///
/// All arguments are stored unchanged in the field of the same name. The
/// additional `has_required_keywords` field is derived here: it is `true`
/// when at least one entry of `keyword_only_parameters` has `required` set.
///
/// This is a `const fn` so that it can be used to initialise a `const` item.
pub const fn new(
324+
cls_name: Option<&'static str>,
325+
func_name: &'static str,
326+
positional_parameter_names: &'static [&'static str],
327+
positional_only_parameters: usize,
328+
required_positional_parameters: usize,
329+
keyword_only_parameters: &'static [KeywordOnlyParameterDescription],
330+
argument_lookup_by_name: fn(&str) -> Option<usize>,
331+
) -> Self {
332+
// Scan the keyword-only parameters for the first one marked `required`.
// Written as a manual index loop because `for` loops and iterator adapters
// are not usable inside a `const fn`.
let mut has_required_keywords = false;
333+
let mut idx = 0;
334+
loop {
335+
// Reached the end without finding a required keyword-only parameter.
if idx >= keyword_only_parameters.len() {
336+
break;
337+
}
338+
// One required parameter is enough; stop scanning.
if keyword_only_parameters[idx].required {
339+
has_required_keywords = true;
340+
break;
341+
}
342+
idx += 1;
343+
}
344+
Self {
345+
cls_name,
346+
func_name,
347+
positional_parameter_names,
348+
positional_only_parameters,
349+
required_positional_parameters,
350+
keyword_only_parameters,
351+
argument_lookup_by_name,
352+
has_required_keywords,
353+
}
354+
}
355+
317356
fn full_name(&self) -> String {
318357
if let Some(cls_name) = self.cls_name {
319358
format!("{}.{}()", cls_name, self.func_name)
@@ -357,18 +396,17 @@ impl FunctionDescription {
357396
// - we both have the GIL and can borrow these input references for the `'py` lifetime.
358397
let args: *const Option<PyArg<'py>> = args.cast();
359398
let positional_args_provided = nargs as usize;
360-
let remaining_positional_args = if args.is_null() {
361-
debug_assert_eq!(positional_args_provided, 0);
399+
let remaining_positional_args = if args.is_null() || positional_args_provided == 0 {
400+
debug_assert_eq!(nargs, 0); // in case args is null
362401
&[]
363402
} else {
364403
// Can consume at most the number of positional parameters in the function definition,
365404
// the rest are varargs.
366405
let positional_args_to_consume =
367406
num_positional_parameters.min(positional_args_provided);
368-
let (positional_parameters, remaining) = unsafe {
369-
std::slice::from_raw_parts(args, positional_args_provided)
370-
.split_at(positional_args_to_consume)
371-
};
407+
let (positional_parameters, remaining) =
408+
unsafe { std::slice::from_raw_parts(args, positional_args_provided) }
409+
.split_at(positional_args_to_consume);
372410
output[..positional_args_to_consume].copy_from_slice(positional_parameters);
373411
remaining
374412
};
@@ -499,8 +537,8 @@ impl FunctionDescription {
499537
output.len(),
500538
num_positional_parameters + self.keyword_only_parameters.len()
501539
);
502-
let mut positional_only_keyword_arguments = Vec::new();
503-
for (kwarg_name_py, value) in kwargs {
540+
let mut kwargs = kwargs.into_iter();
541+
while let Some((kwarg_name_py, value)) = kwargs.next() {
504542
// Safety: All keyword arguments should be UTF-8 strings, but if one is not, `.to_str()`
505543
// will return an error anyway.
506544
#[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
@@ -509,32 +547,21 @@ impl FunctionDescription {
509547

510548
#[cfg(all(not(Py_3_10), Py_LIMITED_API))]
511549
let kwarg_name = kwarg_name_py.extract::<crate::pybacked::PyBackedStr>();
550+
#[cfg(all(not(Py_3_10), Py_LIMITED_API))]
551+
let kwarg_name = kwarg_name.as_deref();
512552

513-
if let Ok(kwarg_name_owned) = kwarg_name {
514-
#[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
515-
let kwarg_name = kwarg_name_owned;
516-
#[cfg(all(not(Py_3_10), Py_LIMITED_API))]
517-
let kwarg_name: &str = &kwarg_name_owned;
518-
519-
// Try to place parameter in keyword only parameters
520-
if let Some(i) = self.find_keyword_parameter_in_keyword_only(kwarg_name) {
521-
if output[i + num_positional_parameters]
522-
.replace(value)
523-
.is_some()
524-
{
525-
return Err(self.multiple_values_for_argument(kwarg_name));
526-
}
527-
continue;
528-
}
529-
530-
// Repeat for positional parameters
531-
if let Some(i) = self.find_keyword_parameter_in_positional(kwarg_name) {
553+
if let Ok(kwarg_name) = kwarg_name {
554+
if let Some(i) = (self.argument_lookup_by_name)(kwarg_name) {
532555
if i < self.positional_only_parameters {
533556
// If accepting **kwargs, then it's allowed for the name of the
534557
// kwarg to conflict with a positional-only argument - the value
535558
// will go into **kwargs anyway.
536559
if K::handle_varkeyword(varkeywords, kwarg_name_py, value, self).is_err() {
537-
positional_only_keyword_arguments.push(kwarg_name_owned);
560+
// otherwise, can bail out into an error pathway
561+
return Err(self.positional_only_keyword_arguments(
562+
kwarg_name,
563+
kwargs.map(|(k, _)| k),
564+
));
538565
}
539566
} else if output[i].replace(value).is_some() {
540567
return Err(self.multiple_values_for_argument(kwarg_name));
@@ -546,35 +573,9 @@ impl FunctionDescription {
546573
K::handle_varkeyword(varkeywords, kwarg_name_py, value, self)?
547574
}
548575

549-
if !positional_only_keyword_arguments.is_empty() {
550-
#[cfg(all(not(Py_3_10), Py_LIMITED_API))]
551-
let positional_only_keyword_arguments: Vec<_> = positional_only_keyword_arguments
552-
.iter()
553-
.map(std::ops::Deref::deref)
554-
.collect();
555-
return Err(self.positional_only_keyword_arguments(&positional_only_keyword_arguments));
556-
}
557-
558576
Ok(())
559577
}
560578

561-
#[inline]
562-
fn find_keyword_parameter_in_positional(&self, kwarg_name: &str) -> Option<usize> {
563-
self.positional_parameter_names
564-
.iter()
565-
.position(|&param_name| param_name == kwarg_name)
566-
}
567-
568-
#[inline]
569-
fn find_keyword_parameter_in_keyword_only(&self, kwarg_name: &str) -> Option<usize> {
570-
// Compare the keyword name against each parameter in turn. This is exactly the same method
571-
// which CPython uses to map keyword names. Although it's O(num_parameters), the number of
572-
// parameters is expected to be small so it's not worth constructing a mapping.
573-
self.keyword_only_parameters
574-
.iter()
575-
.position(|param_desc| param_desc.name == kwarg_name)
576-
}
577-
578579
#[inline]
579580
fn ensure_no_missing_required_positional_arguments(
580581
&self,
@@ -596,6 +597,9 @@ impl FunctionDescription {
596597
&self,
597598
output: &[Option<PyArg<'_>>],
598599
) -> PyResult<()> {
600+
if !self.has_required_keywords {
601+
return Ok(());
602+
}
599603
let keyword_output = &output[self.positional_parameter_names.len()..];
600604
for (param, out) in self.keyword_only_parameters.iter().zip(keyword_output) {
601605
if param.required && out.is_none() {
@@ -648,12 +652,31 @@ impl FunctionDescription {
648652
}
649653

650654
#[cold]
651-
fn positional_only_keyword_arguments(&self, parameter_names: &[&str]) -> PyErr {
655+
fn positional_only_keyword_arguments<'py>(
656+
&self,
657+
current_kwarg: &str,
658+
remaining_kwargs: impl IntoIterator<Item = PyArg<'py>>,
659+
) -> PyErr {
660+
let mut parameter_names = vec![Cow::Borrowed(current_kwarg)];
661+
for kwarg_name_py in remaining_kwargs {
662+
// Safety: All keyword arguments should be UTF-8 strings, but if one is not, `.to_cow()`
663+
// will return an error anyway.
664+
let kwarg_name =
665+
unsafe { kwarg_name_py.cast_unchecked::<crate::types::PyString>() }.to_cow();
666+
667+
if let Ok(kwarg_name) = kwarg_name {
668+
if (self.argument_lookup_by_name)(&*kwarg_name)
669+
.is_some_and(|i| i < self.positional_only_parameters)
670+
{
671+
parameter_names.push(kwarg_name);
672+
}
673+
}
674+
}
652675
let mut msg = format!(
653676
"{} got some positional-only arguments passed as keyword arguments: ",
654677
self.full_name()
655678
);
656-
push_parameter_list(&mut msg, parameter_names);
679+
push_parameter_list(&mut msg, &parameter_names);
657680
PyTypeError::new_err(msg)
658681
}
659682

@@ -805,7 +828,7 @@ pub struct NoVarkeywords;
805828

806829
impl<'py> VarkeywordsHandler<'py> for NoVarkeywords {
807830
type Varkeywords = ();
808-
#[inline]
831+
#[cold]
809832
fn handle_varkeyword(
810833
_varkeywords: &mut Self::Varkeywords,
811834
name: PyArg<'py>,
@@ -834,7 +857,7 @@ impl<'py> VarkeywordsHandler<'py> for DictVarkeywords {
834857
}
835858
}
836859

837-
fn push_parameter_list(msg: &mut String, parameter_names: &[&str]) {
860+
fn push_parameter_list<S: AsRef<str>>(msg: &mut String, parameter_names: &[S]) {
838861
let len = parameter_names.len();
839862
for (i, parameter) in parameter_names.iter().enumerate() {
840863
if i != 0 {
@@ -850,7 +873,7 @@ fn push_parameter_list(msg: &mut String, parameter_names: &[&str]) {
850873
}
851874

852875
msg.push('\'');
853-
msg.push_str(parameter);
876+
msg.push_str(parameter.as_ref());
854877
msg.push('\'');
855878
}
856879
}
@@ -871,6 +894,8 @@ mod tests {
871894
positional_only_parameters: 0,
872895
required_positional_parameters: 0,
873896
keyword_only_parameters: &[],
897+
argument_lookup_by_name: |_: &str| None,
898+
has_required_keywords: false,
874899
};
875900

876901
Python::attach(|py| {
@@ -902,6 +927,8 @@ mod tests {
902927
positional_only_parameters: 0,
903928
required_positional_parameters: 0,
904929
keyword_only_parameters: &[],
930+
argument_lookup_by_name: |_: &str| None,
931+
has_required_keywords: false,
905932
};
906933

907934
Python::attach(|py| {
@@ -933,6 +960,8 @@ mod tests {
933960
positional_only_parameters: 0,
934961
required_positional_parameters: 2,
935962
keyword_only_parameters: &[],
963+
argument_lookup_by_name: |_: &str| None,
964+
has_required_keywords: false,
936965
};
937966

938967
Python::attach(|py| {
@@ -957,7 +986,7 @@ mod tests {
957986
#[test]
958987
fn push_parameter_list_empty() {
959988
let mut s = String::new();
960-
push_parameter_list(&mut s, &[]);
989+
push_parameter_list::<&str>(&mut s, &[]);
961990
assert_eq!(&s, "");
962991
}
963992

0 commit comments

Comments
 (0)