22
33use crate :: {
44 attributes:: { self , take_attributes, take_pyo3_options, CrateAttribute , NameAttribute } ,
5+ get_doc,
56 pyfunction:: { impl_wrap_pyfunction, PyFunctionOptions } ,
6- utils:: { get_pyo3_crate, PythonDoc } ,
7+ utils:: get_pyo3_crate,
78} ;
89use proc_macro2:: TokenStream ;
910use quote:: quote;
@@ -12,7 +13,7 @@ use syn::{
1213 parse:: { Parse , ParseStream } ,
1314 spanned:: Spanned ,
1415 token:: Comma ,
15- Ident , Path , Result , Visibility ,
16+ Item , Path , Result ,
1617} ;
1718
1819#[ derive( Default ) ]
@@ -56,33 +57,154 @@ impl PyModuleOptions {
5657 }
5758}
5859
60+ pub fn pymodule_module_impl ( mut module : syn:: ItemMod ) -> Result < TokenStream > {
61+ let syn:: ItemMod {
62+ attrs,
63+ vis,
64+ unsafety : _,
65+ ident,
66+ mod_token : _,
67+ content,
68+ semi : _,
69+ } = & mut module;
70+ let items = if let Some ( ( _, items) ) = content {
71+ items
72+ } else {
73+ bail_spanned ! ( module. span( ) => "`#[pymodule]` can only be used on inline modules" )
74+ } ;
75+ let options = PyModuleOptions :: from_attrs ( attrs) ?;
76+ let krate = get_pyo3_crate ( & options. krate ) ;
77+ let doc = get_doc ( attrs, None ) ;
78+
79+ let mut module_items = Vec :: new ( ) ;
80+ let mut module_items_cfg_attrs = Vec :: new ( ) ;
81+
82+ fn extract_use_items (
83+ source : & syn:: UseTree ,
84+ cfg_attrs : & [ syn:: Attribute ] ,
85+ target_items : & mut Vec < syn:: Ident > ,
86+ target_cfg_attrs : & mut Vec < Vec < syn:: Attribute > > ,
87+ ) -> Result < ( ) > {
88+ match source {
89+ syn:: UseTree :: Name ( name) => {
90+ target_items. push ( name. ident . clone ( ) ) ;
91+ target_cfg_attrs. push ( cfg_attrs. to_vec ( ) ) ;
92+ }
93+ syn:: UseTree :: Path ( path) => {
94+ extract_use_items ( & path. tree , cfg_attrs, target_items, target_cfg_attrs) ?
95+ }
96+ syn:: UseTree :: Group ( group) => {
97+ for tree in & group. items {
98+ extract_use_items ( tree, cfg_attrs, target_items, target_cfg_attrs) ?
99+ }
100+ }
101+ syn:: UseTree :: Glob ( glob) => {
102+ bail_spanned ! ( glob. span( ) => "#[pymodule] cannot import glob statements" )
103+ }
104+ syn:: UseTree :: Rename ( rename) => {
105+ target_items. push ( rename. rename . clone ( ) ) ;
106+ target_cfg_attrs. push ( cfg_attrs. to_vec ( ) ) ;
107+ }
108+ }
109+ Ok ( ( ) )
110+ }
111+
112+ let mut pymodule_init = None ;
113+
114+ for item in & mut * items {
115+ match item {
116+ Item :: Use ( item_use) => {
117+ let mut is_pyo3 = false ;
118+ item_use. attrs . retain ( |attr| {
119+ let found = attr. path ( ) . is_ident ( "pymodule_export" ) ;
120+ is_pyo3 |= found;
121+ !found
122+ } ) ;
123+ if is_pyo3 {
124+ let cfg_attrs = item_use
125+ . attrs
126+ . iter ( )
127+ . filter ( |attr| attr. path ( ) . is_ident ( "cfg" ) )
128+ . cloned ( )
129+ . collect :: < Vec < _ > > ( ) ;
130+ extract_use_items (
131+ & item_use. tree ,
132+ & cfg_attrs,
133+ & mut module_items,
134+ & mut module_items_cfg_attrs,
135+ ) ?;
136+ }
137+ }
138+ Item :: Fn ( item_fn) => {
139+ let mut is_module_init = false ;
140+ item_fn. attrs . retain ( |attr| {
141+ let found = attr. path ( ) . is_ident ( "pymodule_init" ) ;
142+ is_module_init |= found;
143+ !found
144+ } ) ;
145+ if is_module_init {
146+ ensure_spanned ! ( pymodule_init. is_none( ) , item_fn. span( ) => "only one pymodule_init may be specified" ) ;
147+ let ident = & item_fn. sig . ident ;
148+ pymodule_init = Some ( quote ! { #ident( module) ?; } ) ;
149+ } else {
150+ bail_spanned ! ( item. span( ) => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]" )
151+ }
152+ }
153+ item => {
154+ bail_spanned ! ( item. span( ) => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]" )
155+ }
156+ }
157+ }
158+
159+ let initialization = module_initialization ( options, ident) ;
160+ Ok ( quote ! (
161+ #vis mod #ident {
162+ #( #items) *
163+
164+ #initialization
165+
166+ impl MakeDef {
167+ const fn make_def( ) -> #krate:: impl_:: pymodule:: ModuleDef {
168+ use #krate:: impl_:: pymodule as impl_;
169+ const INITIALIZER : impl_:: ModuleInitializer = impl_:: ModuleInitializer ( __pyo3_pymodule) ;
170+ unsafe {
171+ impl_:: ModuleDef :: new(
172+ __PYO3_NAME,
173+ #doc,
174+ INITIALIZER
175+ )
176+ }
177+ }
178+ }
179+
180+ fn __pyo3_pymodule( module: & #krate:: Bound <' _, #krate:: types:: PyModule >) -> #krate:: PyResult <( ) > {
181+ use #krate:: impl_:: pymodule:: PyAddToModule ;
182+ #(
183+ #( #module_items_cfg_attrs) *
184+ #module_items:: add_to_module( module) ?;
185+ ) *
186+ #pymodule_init
187+ Ok ( ( ) )
188+ }
189+ }
190+ ) )
191+ }
192+
59193/// Generates the function that is called by the python interpreter to initialize the native
60194/// module
61- pub fn pymodule_impl (
62- fnname : & Ident ,
63- options : PyModuleOptions ,
64- doc : PythonDoc ,
65- visibility : & Visibility ,
66- ) -> TokenStream {
67- let name = options. name . unwrap_or_else ( || fnname. unraw ( ) ) ;
195+ pub fn pymodule_function_impl ( mut function : syn:: ItemFn ) -> Result < TokenStream > {
196+ let options = PyModuleOptions :: from_attrs ( & mut function. attrs ) ?;
197+ process_functions_in_module ( & options, & mut function) ?;
68198 let krate = get_pyo3_crate ( & options. krate ) ;
69- let pyinit_symbol = format ! ( "PyInit_{}" , name) ;
199+ let ident = & function. sig . ident ;
200+ let vis = & function. vis ;
201+ let doc = get_doc ( & function. attrs , None ) ;
70202
71- quote ! {
72- // Create a module with the same name as the `#[pymodule]` - this way `use <the module>`
73- // will actually bring both the module and the function into scope.
74- #[ doc( hidden) ]
75- #visibility mod #fnname {
76- pub ( crate ) struct MakeDef ;
77- pub static DEF : #krate:: impl_:: pymodule:: ModuleDef = MakeDef :: make_def( ) ;
78- pub const NAME : & ' static str = concat!( stringify!( #name) , "\0 " ) ;
79-
80- /// This autogenerated function is called by the python interpreter when importing
81- /// the module.
82- #[ export_name = #pyinit_symbol]
83- pub unsafe extern "C" fn init( ) -> * mut #krate:: ffi:: PyObject {
84- #krate:: impl_:: trampoline:: module_init( |py| DEF . make_module( py) )
85- }
203+ let initialization = module_initialization ( options, ident) ;
204+ Ok ( quote ! {
205+ #function
206+ #vis mod #ident {
207+ #initialization
86208 }
87209
88210 // Generate the definition inside an anonymous function in the same scope as the original function -
@@ -91,28 +213,59 @@ pub fn pymodule_impl(
91213 // inside a function body)
92214 const _: ( ) = {
93215 use #krate:: impl_:: pymodule as impl_;
94- impl #fnname:: MakeDef {
216+
217+ fn __pyo3_pymodule( module: & #krate:: Bound <' _, #krate:: types:: PyModule >) -> #krate:: PyResult <( ) > {
218+ #ident( module. py( ) , module. as_gil_ref( ) )
219+ }
220+
221+ impl #ident:: MakeDef {
95222 const fn make_def( ) -> impl_:: ModuleDef {
96- const INITIALIZER : impl_:: ModuleInitializer = impl_:: ModuleInitializer ( #fnname) ;
97223 unsafe {
98- impl_:: ModuleDef :: new( #fnname:: NAME , #doc, INITIALIZER )
224+ const INITIALIZER : impl_:: ModuleInitializer = impl_:: ModuleInitializer ( __pyo3_pymodule) ;
225+ impl_:: ModuleDef :: new(
226+ #ident:: __PYO3_NAME,
227+ #doc,
228+ INITIALIZER
229+ )
99230 }
100231 }
101232 }
102233 } ;
234+ } )
235+ }
236+
237+ fn module_initialization ( options : PyModuleOptions , ident : & syn:: Ident ) -> TokenStream {
238+ let name = options. name . unwrap_or_else ( || ident. unraw ( ) ) ;
239+ let krate = get_pyo3_crate ( & options. krate ) ;
240+ let pyinit_symbol = format ! ( "PyInit_{}" , name) ;
241+
242+ quote ! {
243+ pub const __PYO3_NAME: & ' static str = concat!( stringify!( #name) , "\0 " ) ;
244+
245+ pub ( super ) struct MakeDef ;
246+ pub static DEF : #krate:: impl_:: pymodule:: ModuleDef = MakeDef :: make_def( ) ;
247+
248+ pub fn add_to_module( module: & #krate:: Bound <' _, #krate:: types:: PyModule >) -> #krate:: PyResult <( ) > {
249+ use #krate:: prelude:: PyModuleMethods ;
250+ module. add_submodule( DEF . make_module( module. py( ) ) ?. bind( module. py( ) ) )
251+ }
252+
253+ /// This autogenerated function is called by the python interpreter when importing
254+ /// the module.
255+ #[ export_name = #pyinit_symbol]
256+ pub unsafe extern "C" fn __pyo3_init( ) -> * mut #krate:: ffi:: PyObject {
257+ #krate:: impl_:: trampoline:: module_init( |py| DEF . make_module( py) )
258+ }
103259 }
104260}
105261
106262/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
107- pub fn process_functions_in_module (
108- options : & PyModuleOptions ,
109- func : & mut syn:: ItemFn ,
110- ) -> syn:: Result < ( ) > {
263+ fn process_functions_in_module ( options : & PyModuleOptions , func : & mut syn:: ItemFn ) -> Result < ( ) > {
111264 let mut stmts: Vec < syn:: Stmt > = Vec :: new ( ) ;
112265 let krate = get_pyo3_crate ( & options. krate ) ;
113266
114267 for mut stmt in func. block . stmts . drain ( ..) {
115- if let syn:: Stmt :: Item ( syn :: Item :: Fn ( func) ) = & mut stmt {
268+ if let syn:: Stmt :: Item ( Item :: Fn ( func) ) = & mut stmt {
116269 if let Some ( pyfn_args) = get_pyfn_attr ( & mut func. attrs ) ? {
117270 let module_name = pyfn_args. modname ;
118271 let wrapped_function = impl_wrap_pyfunction ( func, pyfn_args. options ) ?;
0 commit comments