Skip to content

Commit b3ee70d

Browse files
committed
refactor parse_fn_type
1 parent 1158c08 commit b3ee70d

File tree

1 file changed

+41
-41
lines changed

1 file changed

+41
-41
lines changed

pyo3-macros-backend/src/method.rs

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ pub enum FnType {
8585
}
8686

8787
impl FnType {
88+
pub fn skip_first_rust_argument_in_python_signature(&self) -> bool {
89+
match self {
90+
FnType::Getter(_)
91+
| FnType::Setter(_)
92+
| FnType::Fn(_)
93+
| FnType::FnClass
94+
| FnType::FnNewClass
95+
| FnType::FnModule => true,
96+
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => false,
97+
}
98+
}
99+
88100
pub fn self_arg(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream {
89101
match self {
90102
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) => {
@@ -264,35 +276,35 @@ impl<'a> FnSpec<'a> {
264276

265277
let mut python_name = name.map(|name| name.value.0);
266278

267-
let (fn_type, skip_first_arg, fixed_convention) =
268-
Self::parse_fn_type(sig, meth_attrs, &mut python_name)?;
279+
let fn_type = Self::parse_fn_type(sig, meth_attrs, &mut python_name)?;
269280
ensure_signatures_on_valid_method(&fn_type, signature.as_ref(), text_signature.as_ref())?;
270281

271282
let name = &sig.ident;
272283
let ty = get_return_info(&sig.output);
273284
let python_name = python_name.as_ref().unwrap_or(name).unraw();
274285

275-
let arguments: Vec<_> = if skip_first_arg {
276-
sig.inputs
277-
.iter_mut()
278-
.skip(1)
279-
.map(FnArg::parse)
280-
.collect::<Result<_>>()?
281-
} else {
282-
sig.inputs
283-
.iter_mut()
284-
.map(FnArg::parse)
285-
.collect::<Result<_>>()?
286-
};
286+
let arguments: Vec<_> = sig
287+
.inputs
288+
.iter_mut()
289+
.skip(if fn_type.skip_first_rust_argument_in_python_signature() {
290+
1
291+
} else {
292+
0
293+
})
294+
.map(FnArg::parse)
295+
.collect::<Result<_>>()?;
287296

288297
let signature = if let Some(signature) = signature {
289298
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
290299
} else {
291300
FunctionSignature::from_arguments(arguments)?
292301
};
293302

294-
let convention =
295-
fixed_convention.unwrap_or_else(|| CallingConvention::from_signature(&signature));
303+
let convention = if matches!(fn_type, FnType::FnNew | FnType::FnNewClass) {
304+
CallingConvention::TpNew
305+
} else {
306+
CallingConvention::from_signature(&signature)
307+
};
296308

297309
Ok(FnSpec {
298310
tp: fn_type,
@@ -314,7 +326,7 @@ impl<'a> FnSpec<'a> {
314326
sig: &syn::Signature,
315327
meth_attrs: &mut Vec<syn::Attribute>,
316328
python_name: &mut Option<syn::Ident>,
317-
) -> Result<(FnType, bool, Option<CallingConvention>)> {
329+
) -> Result<FnType> {
318330
let mut method_attributes = parse_method_attributes(meth_attrs)?;
319331

320332
let name = &sig.ident;
@@ -334,16 +346,12 @@ impl<'a> FnSpec<'a> {
334346
.map(|stripped| syn::Ident::new(stripped, name.span()))
335347
};
336348

337-
let (fn_type, skip_first_arg, fixed_convention) = match method_attributes.as_mut_slice() {
338-
[] => (
339-
FnType::Fn(parse_receiver(
340-
"static method needs #[staticmethod] attribute",
341-
)?),
342-
true,
343-
None,
344-
),
345-
[MethodTypeAttribute::StaticMethod(_)] => (FnType::FnStatic, false, None),
346-
[MethodTypeAttribute::ClassAttribute(_)] => (FnType::ClassAttribute, false, None),
349+
let fn_type = match method_attributes.as_mut_slice() {
350+
[] => FnType::Fn(parse_receiver(
351+
"static method needs #[staticmethod] attribute",
352+
)?),
353+
[MethodTypeAttribute::StaticMethod(_)] => FnType::FnStatic,
354+
[MethodTypeAttribute::ClassAttribute(_)] => FnType::ClassAttribute,
347355
[MethodTypeAttribute::New(_)]
348356
| [MethodTypeAttribute::New(_), MethodTypeAttribute::ClassMethod(_)]
349357
| [MethodTypeAttribute::ClassMethod(_), MethodTypeAttribute::New(_)] => {
@@ -352,12 +360,12 @@ impl<'a> FnSpec<'a> {
352360
}
353361
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
354362
if matches!(method_attributes.as_slice(), [MethodTypeAttribute::New(_)]) {
355-
(FnType::FnNew, false, Some(CallingConvention::TpNew))
363+
FnType::FnNew
356364
} else {
357-
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
365+
FnType::FnNewClass
358366
}
359367
}
360-
[MethodTypeAttribute::ClassMethod(_)] => (FnType::FnClass, true, None),
368+
[MethodTypeAttribute::ClassMethod(_)] => FnType::FnClass,
361369
[MethodTypeAttribute::Getter(_, name)] => {
362370
if let Some(name) = name.take() {
363371
ensure_spanned!(
@@ -369,11 +377,7 @@ impl<'a> FnSpec<'a> {
369377
*python_name = strip_fn_name("get_");
370378
}
371379

372-
(
373-
FnType::Getter(parse_receiver("expected receiver for `#[getter]`")?),
374-
true,
375-
None,
376-
)
380+
FnType::Getter(parse_receiver("expected receiver for `#[getter]`")?)
377381
}
378382
[MethodTypeAttribute::Setter(_, name)] => {
379383
if let Some(name) = name.take() {
@@ -386,11 +390,7 @@ impl<'a> FnSpec<'a> {
386390
*python_name = strip_fn_name("set_");
387391
}
388392

389-
(
390-
FnType::Setter(parse_receiver("expected receiver for `#[setter]`")?),
391-
true,
392-
None,
393-
)
393+
FnType::Setter(parse_receiver("expected receiver for `#[setter]`")?)
394394
}
395395
[first, rest @ .., last] => {
396396
// Join as many of the spans together as possible
@@ -416,7 +416,7 @@ impl<'a> FnSpec<'a> {
416416
bail_spanned!(span => msg)
417417
}
418418
};
419-
Ok((fn_type, skip_first_arg, fixed_convention))
419+
Ok(fn_type)
420420
}
421421

422422
/// Return a C wrapper function for this signature.

0 commit comments

Comments
 (0)