Skip to content

BitGenerator support #499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
06d6ce1
WIP bitgen
flying-sheep Jun 6, 2025
07e2416
nonnull
flying-sheep Jun 6, 2025
b611943
fix and test
flying-sheep Jun 6, 2025
d93a264
cmt
flying-sheep Jun 6, 2025
f52b2fa
safer: don’t allow trying to get `BitGen` from any PyAny
flying-sheep Jun 7, 2025
05814d6
less indirection
flying-sheep Jun 7, 2025
37d360e
add tryfrom
flying-sheep Jun 7, 2025
eed5b19
implement rand
flying-sheep Jun 7, 2025
6c1a89b
fmt
flying-sheep Jun 7, 2025
d1909d3
rename and deref
flying-sheep Jun 8, 2025
bde2553
order
flying-sheep Jun 8, 2025
a0b9ec5
make into lock
flying-sheep Jun 8, 2025
ee32246
docs
flying-sheep Jun 8, 2025
1be6838
more docs
flying-sheep Jun 8, 2025
2aa3d90
guard
flying-sheep Jun 8, 2025
0258e6d
call_method0
flying-sheep Jun 8, 2025
876001b
reaname test
flying-sheep Jun 8, 2025
71ce8be
manually drop and capsule
flying-sheep Jun 8, 2025
2de7072
remove useless test
flying-sheep Jun 8, 2025
016eb7a
doctests
flying-sheep Jun 8, 2025
1f7f37f
smaller
flying-sheep Jun 8, 2025
1d01c7a
clarify where to release the GIL
flying-sheep Jun 8, 2025
c90176a
safety
flying-sheep Jun 8, 2025
f49d3fa
oops
flying-sheep Jun 8, 2025
a16846d
less unsafe
flying-sheep Jun 8, 2025
573d890
add thread test
flying-sheep Jun 8, 2025
06bb693
back to lock acquiring
flying-sheep Jun 8, 2025
663fa29
docs
flying-sheep Jun 9, 2025
c6105c9
no copy/clone
flying-sheep Jun 10, 2025
3a0aa92
rename to release
flying-sheep Jun 10, 2025
a92861a
remove lifetime
flying-sheep Jun 10, 2025
6dbb6dc
static
flying-sheep Jun 10, 2025
b102d20
no mut ref conversion
flying-sheep Jun 10, 2025
e5e440e
disambiguate
flying-sheep Jun 10, 2025
e73e3a2
rand_core only
flying-sheep Jun 10, 2025
c6493df
rename bitgen type
flying-sheep Jun 10, 2025
2327f36
c_str macro
flying-sheep Jun 10, 2025
e5c6458
intern strings
flying-sheep Jun 10, 2025
e8cd5e8
docs
flying-sheep Jun 10, 2025
0868405
more doc
flying-sheep Jun 10, 2025
8667203
clean up tests
flying-sheep Jun 10, 2025
1fd7bb5
no let-else
flying-sheep Jun 10, 2025
3913171
use GILOnceCell::import
flying-sheep Jun 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .vscode/settings.json
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be removed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, will do when I’m done. I like working on multiple machines, and I don’t like re-doing settings for individual projects

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"[rust]": {
"editor.defaultFormatter": "rust-lang.rust-analyzer",
"editor.formatOnSave": true,
},
"rust-analyzer.cargo.features": "all",
}
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ num-integer = "0.1"
num-traits = "0.2"
ndarray = ">= 0.15, < 0.17"
pyo3 = { version = "0.25.0", default-features = false, features = ["macros"] }
rand_core = { version = "0.9.3", default-features = false, optional = true }
rustc-hash = "2.0"

[dev-dependencies]
Expand All @@ -32,6 +33,7 @@ pyo3 = { version = "0.25", default-features = false, features = [
nalgebra = { version = ">=0.30, <0.34", default-features = false, features = [
"std",
] }
rand = { version = "0.9.1", default-features = false }

[build-dependencies]
pyo3-build-config = { version = "0.25", features = ["resolve-config"] }
Expand Down
2 changes: 1 addition & 1 deletion examples/simple/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ fn rust_ext<'py>(m: &Bound<'py, PyModule>) -> PyResult<()> {
// This crate follows a strongly-typed approach to wrapping NumPy arrays
// while Python API are often expected to work with multiple element types.
//
// That kind of limited polymorphis can be recovered by accepting an enumerated type
// That kind of limited polymorphism can be recovered by accepting an enumerated type
// covering the supported element types and dispatching into a generic implementation.
#[derive(FromPyObject)]
enum SupportedArray<'py> {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub mod datetime;
mod dtype;
mod error;
pub mod npyffi;
pub mod random;
mod slice_container;
mod strings;
mod sum_products;
Expand Down
2 changes: 2 additions & 0 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ macro_rules! impl_api {
pub mod array;
pub mod flags;
pub mod objects;
pub mod random;
pub mod types;
pub mod ufunc;

pub use self::array::*;
pub use self::flags::*;
pub use self::objects::*;
pub use self::random::*;
pub use self::types::*;
pub use self::ufunc::*;
2 changes: 1 addition & 1 deletion src/npyffi/objects.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Low-Lebel binding for NumPy C API C-objects
//! Low-Level binding for NumPy C API C-objects
//!
//! <https://numpy.org/doc/stable/reference/c-api/types-and-structures.html>
#![allow(non_camel_case_types)]
Expand Down
11 changes: 11 additions & 0 deletions src/npyffi/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use std::ffi::c_void;

#[repr(C)]
#[derive(Debug)]
pub struct bitgen_t {
pub state: *mut c_void,
pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil
pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil
pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil
pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil
}
302 changes: 302 additions & 0 deletions src/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
//! Safe interface for NumPy's random [`BitGenerator`][bg].
//!
//! Using the patterns described in [“Extending `numpy.random`”][ext],
//! you can generate random numbers without holding the GIL,
//! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]:
//!
//! ```
//! use pyo3::prelude::*;
//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _};
//!
//! fn default_bit_gen<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> {
//! let default_rng = py.import("numpy.random")?.call_method0("default_rng")?;
//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?;
//! Ok(bit_generator)
//! }
//!
//! let random_number = Python::with_gil(|py| -> PyResult<_> {
//! let mut bitgen = default_bit_gen(py)?.lock()?;
//! // use bitgen without holding the GIL
//! let r = py.allow_threads(|| bitgen.next_u64());
//! // release the lock manually while holding the GIL again
//! bitgen.release(py)?;
//! Ok(r)
//! })?;
//! # Ok::<(), PyErr>(())
//! ```
//!
//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]:
//!
//! ```
//! # use pyo3::prelude::*;
//! use rand::Rng as _;
//! # use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _};
//! # // TODO: reuse function definition from above?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like there should be a convenient way to get this. I'm thinking about something like

impl PyBitGenerator {
     fn new(py: Python<'_>) -> PyResult<Bound<..>>;
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are many implementations, we’d have to cover all of them.

I’d rather leave this minimal until this PR is mostly done.

//! # fn default_bit_gen<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> {
//! # let default_rng = py.import("numpy.random")?.call_method0("default_rng")?;
//! # let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?;
//! # Ok(bit_generator)
//! # }
//!
//! Python::with_gil(|py| -> PyResult<_> {
//! let mut bitgen = default_bit_gen(py)?.lock()?;
//! if bitgen.random_ratio(1, 1_000_000) {
//! println!("a sure thing");
//! }
//! bitgen.release(py)?;
//! Ok(())
//! })?;
//! # Ok::<(), PyErr>(())
//! ```
//!
//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
//! [ext]: https://numpy.org/doc/stable/reference/random/extending.html

use std::ptr::NonNull;

use pyo3::{
exceptions::PyRuntimeError,
ffi, intern,
prelude::*,
sync::GILOnceCell,
types::{DerefToPyAny, PyCapsule, PyType},
PyTypeInfo,
};

use crate::npyffi::bitgen_t;

/// Wrapper for [`np.random.BitGenerator`][bg].
///
/// See also [`PyBitGeneratorMethods`].
///
/// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
#[repr(transparent)]
pub struct PyBitGenerator(PyAny);

impl DerefToPyAny for PyBitGenerator {}

unsafe impl PyTypeInfo for PyBitGenerator {
const NAME: &'static str = "PyBitGenerator";
const MODULE: Option<&'static str> = Some("numpy.random");

fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
static CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
let cls = CLS
.import(py, "numpy.random", "BitGenerator")
.expect("Failed to get BitGenerator type object");
cls.as_type_ptr()
}
}

/// Methods for [`PyBitGenerator`].
pub trait PyBitGeneratorMethods {
/// Acquire a lock on the BitGenerator to allow calling its methods in.
fn lock(&self) -> PyResult<PyBitGeneratorGuard>;
}

impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> {
fn lock(&self) -> PyResult<PyBitGeneratorGuard> {
let py = self.py();
let capsule = self
.getattr(intern!(py, "capsule"))?
.downcast_into::<PyCapsule>()?;
let lock = self.getattr(intern!(py, "lock"))?;
// we’re holding the GIL, so there’s no race condition checking the lock and acquiring it later.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may not be true under free-threaded Python. Is the lock known to be threadsafe and acquire simply fails if the lock is already acquired? If not we may need to guard the whole module under cfg(not(Py_GIL_DISABLED))

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn’t fail, it hangs, but that’s configurable with a timeout or by making it non-blocking: https://docs.python.org/3/library/threading.html#threading.Lock.acquire

and it’s a threading.Lock!

if lock.call_method0(intern!(py, "locked"))?.extract()? {
return Err(PyRuntimeError::new_err("BitGenerator is already locked"));
}
lock.call_method0(intern!(py, "acquire"))?;

assert_eq!(capsule.name()?, Some(ffi::c_str!("BitGenerator")));
let ptr = capsule.pointer() as *mut bitgen_t;
let non_null = match NonNull::new(ptr) {
Some(non_null) => non_null,
None => {
lock.call_method0(intern!(py, "release"))?;
return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule"));
}
};
Ok(PyBitGeneratorGuard {
raw_bitgen: non_null,
_capsule: capsule.unbind(),
lock: lock.unbind(),
})
}
}

impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard {
type Error = PyErr;
fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> {
value.lock()
}
}

/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL.
///
/// Since [dropping](`Drop::drop`) this acquires the GIL,
/// prefer to call [`release`][`PyBitGeneratorGuard::release`] manually to release the lock.
pub struct PyBitGeneratorGuard {
raw_bitgen: NonNull<bitgen_t>,
/// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated.
_capsule: Py<PyCapsule>,
/// This lock makes sure no other threads try to use the BitGenerator while we do.
lock: Py<PyAny>,
}

// SAFETY: 1. We don’t hold the GIL, so we can’t access the Python objects.
// 2. We only access `raw_bitgen` from `&mut self`, which protects it from parallel access.
unsafe impl Send for PyBitGeneratorGuard {}

impl Drop for PyBitGeneratorGuard {
fn drop(&mut self) {
// ignore errors. This includes when `release` was called manually.
let _ = Python::with_gil(|py| -> PyResult<_> {
self.lock.bind(py).call_method0(intern!(py, "release"))?;
Ok(())
});
}
}

// SAFETY: 1. We hold the `BitGenerator.lock`, so nothing apart from us is allowed to change its state.
// 2. We hold the `BitGenerator.capsule`, so it can’t be deallocated.
impl<'py> PyBitGeneratorGuard {
/// Release the lock, allowing for checking for errors.
pub fn release(self, py: Python<'py>) -> PyResult<()> {
self.lock.bind(py).call_method0(intern!(py, "release"))?;
Ok(())
}

/// Returns the next random unsigned 64 bit integer.
pub fn next_u64(&mut self) -> u64 {
unsafe {
let bitgen = self.raw_bitgen.as_ptr();
debug_assert_ne!((*bitgen).state, std::ptr::null_mut());
((*bitgen).next_uint64)((*bitgen).state)
}
}
/// Returns the next random unsigned 32 bit integer.
pub fn next_u32(&mut self) -> u32 {
unsafe {
let bitgen = self.raw_bitgen.as_ptr();
debug_assert_ne!((*bitgen).state, std::ptr::null_mut());
((*bitgen).next_uint32)((*bitgen).state)
}
}
/// Returns the next random double.
pub fn next_double(&mut self) -> libc::c_double {
unsafe {
let bitgen = self.raw_bitgen.as_ptr();
debug_assert_ne!((*bitgen).state, std::ptr::null_mut());
((*bitgen).next_double)((*bitgen).state)
}
}
/// Returns the next raw value (can be used for testing).
pub fn next_raw(&mut self) -> u64 {
unsafe {
let bitgen = self.raw_bitgen.as_ptr();
debug_assert_ne!((*bitgen).state, std::ptr::null_mut());
((*bitgen).next_raw)((*bitgen).state)
}
}
}

#[cfg(feature = "rand_core")]
impl rand_core::RngCore for PyBitGeneratorGuard {
fn next_u32(&mut self) -> u32 {
PyBitGeneratorGuard::next_u32(self)
}
fn next_u64(&mut self) -> u64 {
PyBitGeneratorGuard::next_u64(self)
}
fn fill_bytes(&mut self, dst: &mut [u8]) {
rand_core::impls::fill_bytes_via_next(self, dst)
}
}

#[cfg(test)]
mod tests {
use super::*;

fn get_bit_generator<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> {
let default_rng = py.import("numpy.random")?.call_method0("default_rng")?;
let bit_generator = default_rng
.getattr("bit_generator")?
.downcast_into::<PyBitGenerator>()?;
Ok(bit_generator)
}

/// Test the primary use case: acquire the lock, release the GIL, then use the lock
#[test]
fn use_outside_gil() -> PyResult<()> {
Python::with_gil(|py| {
let mut bitgen = get_bit_generator(py)?.lock()?;
py.allow_threads(|| {
let _ = bitgen.next_raw();
});
assert!(bitgen.release(py).is_ok());
Ok(())
})
}

/// More complex version of primary use case: use from multiple threads
#[cfg(feature = "rand_core")]
#[test]
fn use_parallel() -> PyResult<()> {
use crate::array::{PyArray2, PyArrayMethods as _};
use ndarray::Dimension;
use rand::Rng;
use std::sync::{Arc, Mutex};

Python::with_gil(|py| -> PyResult<_> {
let mut arr = PyArray2::<u32>::zeros(py, (2, 300), false).readwrite();
let bitgen = get_bit_generator(py)?.lock()?;
let bitgen = Arc::new(Mutex::new(bitgen));

let (_n_threads, chunk_size) = arr.dims().into_pattern();
let slice = arr.as_slice_mut()?;

py.allow_threads(|| {
std::thread::scope(|s| {
for chunk in slice.chunks_exact_mut(chunk_size) {
let bitgen = Arc::clone(&bitgen);
s.spawn(move || {
let mut bitgen = bitgen.lock().unwrap();
chunk.fill_with(|| bitgen.random_range(10..200));
});
}
})
});

std::mem::drop(bitgen);
Ok(())
})
}

/// Test that the `rand::Rng` APIs work
#[cfg(feature = "rand_core")]
#[test]
fn rand() -> PyResult<()> {
use rand::Rng as _;

Python::with_gil(|py| {
let mut bitgen = get_bit_generator(py)?.lock()?;
py.allow_threads(|| {
assert!(bitgen.random_ratio(1, 1));
assert!(!bitgen.random_ratio(0, 1));
});
assert!(bitgen.release(py).is_ok());
Ok(())
})
}

#[test]
fn double_lock_fails() -> PyResult<()> {
Python::with_gil(|py| {
let generator = get_bit_generator(py)?;
let bitgen = generator.lock()?;
assert!(generator.lock().is_err());
assert!(bitgen.release(py).is_ok());
Ok(())
})
}
}
Loading