Skip to content

Commit 7a660ad

Browse files
committed
pyo3: support module prefix + naming
1 parent a3bb997 commit 7a660ad

File tree

4 files changed

+72
-6
lines changed

4 files changed

+72
-6
lines changed

extensions/pyo3/private/pyo3.bzl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,20 @@ def _py_pyo3_library_impl(ctx):
8787
is_windows = extension.basename.endswith(".dll")
8888

8989
# https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds
90-
ext = ctx.actions.declare_file("{}{}".format(
91-
ctx.label.name,
92-
".pyd" if is_windows else ".so",
93-
))
90+
# Determine the on-disk and logical Python module layout.
91+
module_name = ctx.attr.module if hasattr(ctx.attr, "module") and ctx.attr.module else ctx.label.name
92+
93+
# Convert a dotted prefix (e.g. "foo.bar") into a path ("foo/bar").
94+
module_prefix = ctx.attr.module_prefix if hasattr(ctx.attr, "module_prefix") and ctx.attr.module_prefix else None
95+
if module_prefix:
96+
module_prefix_path = module_prefix.replace(".", "/")
97+
module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so")
98+
stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name)
99+
else:
100+
module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so")
101+
stub_relpath = "{}.pyi".format(module_name)
102+
103+
ext = ctx.actions.declare_file(module_relpath)
94104
ctx.actions.symlink(
95105
output = ext,
96106
target_file = extension,
@@ -99,10 +109,10 @@ def _py_pyo3_library_impl(ctx):
99109

100110
stub = None
101111
if _stubs_enabled(ctx.attr.stubs, toolchain):
102-
stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name))
112+
stub = ctx.actions.declare_file(stub_relpath)
103113

104114
args = ctx.actions.args()
105-
args.add(ctx.label.name, format = "--module_name=%s")
115+
args.add(module_name, format = "--module_name=%s")
106116
args.add(ext, format = "--module_path=%s")
107117
args.add(stub, format = "--output=%s")
108118
ctx.actions.run(
@@ -180,6 +190,12 @@ py_pyo3_library = rule(
180190
"imports": attr.string_list(
181191
doc = "List of import directories to be added to the `PYTHONPATH`.",
182192
),
193+
"module": attr.string(
194+
doc = "The Python module name implemented by this extension.",
195+
),
196+
"module_prefix": attr.string(
197+
doc = "A dotted Python package prefix for the module.",
198+
),
183199
"stubs": attr.int(
184200
doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.",
185201
default = -1,
@@ -218,6 +234,8 @@ def pyo3_extension(
218234
stubs = None,
219235
version = None,
220236
compilation_mode = "opt",
237+
module = None,
238+
module_prefix = None,
221239
**kwargs):
222240
"""Define a PyO3 python extension module.
223241
@@ -259,6 +277,8 @@ def pyo3_extension(
259277
For more details see [rust_shared_library][rsl].
260278
compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode)
261279
value to build the extension for. If set to `"current"`, the current configuration will be used.
280+
module (str, optional): The Python module name implemented by this extension.
281+
module_prefix (str, optional): A dotted Python package prefix for the module.
262282
**kwargs (dict): Additional keyword arguments.
263283
"""
264284
tags = kwargs.pop("tags", [])
@@ -318,6 +338,8 @@ def pyo3_extension(
318338
compilation_mode = compilation_mode,
319339
stubs = stubs_int,
320340
imports = imports,
341+
module = module,
342+
module_prefix = module_prefix,
321343
tags = tags,
322344
visibility = visibility,
323345
**kwargs
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@rules_python//python:defs.bzl", "py_test")
2+
load("//:defs.bzl", "pyo3_extension")
3+
4+
pyo3_extension(
5+
name = "module_prefix",
6+
srcs = ["bar.rs"],
7+
edition = "2021",
8+
imports = ["."],
9+
module = "bar",
10+
module_prefix = "foo",
11+
)
12+
13+
py_test(
14+
name = "module_prefix_import_test",
15+
srcs = ["module_prefix_import_test.py"],
16+
deps = [":module_prefix"],
17+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use pyo3::prelude::*;
2+
3+
#[pyfunction]
4+
fn thing() -> PyResult<&'static str> {
5+
Ok("hello from rust")
6+
}
7+
8+
#[pymodule]
9+
fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
10+
m.add_function(wrap_pyfunction!(thing, m)?)?;
11+
Ok(())
12+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Tests that a pyo3 extension can be imported via a module prefix."""
2+
3+
import unittest
4+
5+
import foo.bar
6+
7+
8+
class ModulePrefixImportTest(unittest.TestCase):
9+
def test_import_and_call(self) -> None:
10+
result = foo.bar.thing()
11+
self.assertEqual("hello from rust", result)
12+
13+
14+
if __name__ == "__main__":
15+
unittest.main()

0 commit comments

Comments
 (0)