Skip to content

Add contextvars types #5022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/sealed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::types::{
PyBool, PyByteArray, PyBytes, PyCapsule, PyComplex, PyDict, PyFloat, PyFrozenSet, PyList,
PyMapping, PyMappingProxy, PyModule, PySequence, PySet, PySlice, PyString, PyTraceback,
PyTuple, PyType, PyWeakref, PyWeakrefProxy, PyWeakrefReference,
PyBool, PyByteArray, PyBytes, PyCapsule, PyComplex, PyContext, PyContextVar, PyContextToken,
PyDict, PyFloat, PyFrozenSet, PyList, PyMapping, PyMappingProxy, PyModule, PySequence, PySet,
PySlice, PyString, PyTraceback, PyTuple, PyType, PyWeakref, PyWeakrefProxy, PyWeakrefReference,
};
use crate::{ffi, Bound, PyAny, PyResult};

Expand All @@ -28,6 +28,9 @@ impl Sealed for Bound<'_, PyByteArray> {}
impl Sealed for Bound<'_, PyBytes> {}
impl Sealed for Bound<'_, PyCapsule> {}
impl Sealed for Bound<'_, PyComplex> {}
impl Sealed for Bound<'_, PyContext> {}
impl Sealed for Bound<'_, PyContextVar> {}
impl Sealed for Bound<'_, PyContextToken> {}
impl Sealed for Bound<'_, PyDict> {}
impl Sealed for Bound<'_, PyFloat> {}
impl Sealed for Bound<'_, PyFrozenSet> {}
Expand Down
250 changes: 250 additions & 0 deletions src/types/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
//! Safe Rust wrappers for types defined in the Python `contextvars` library
//!
//! For more details about these types, see the [Python
//! documentation](https://docs.python.org/3/library/contextvars.html)

use crate::err::PyResult;
use crate::sync::GILOnceCell;
use crate::{ffi, Py, PyTypeInfo};
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::py_result_ext::PyResultExt;
use crate::{Bound, BoundObject, IntoPyObject, PyAny, PyErr, Python};
use std::ffi::CStr;
use std::ptr;

use super::{PyAnyMethods, PyString};

/// Implementation of functionality for [`PyContext`].
///
/// These methods are defined for the `Bound<'py, PyContext>` smart pointer, so to use method call
/// syntax these methods are separated into a trait, because stable Rust does not yet support
/// `arbitrary_self_types`.
#[doc(alias = "PyContext")]
pub trait PyContextMethods<'py>: crate::sealed::Sealed {
/// Return a shallow copy of the context object.
fn copy(&self) -> PyResult<Bound<'py, PyContext>>;
/// Set ctx as the current context for the current thread
fn enter(&self) -> PyResult<()>;
/// Deactivate the context and restore the previous context as the current context for the current thread
fn exit(&self) -> PyResult<()>;
}

/// Implementation of functionality for [`PyContextVar`].
///
/// These methods are defined for the `Bound<'py, PyContextVar>` smart pointer, so to use method call
/// syntax these methods are separated into a trait, because stable Rust does not yet support
/// `arbitrary_self_types`.
#[doc(alias = "PyContextVar")]
pub trait PyContextVarMethods<'py>: crate::sealed::Sealed {
/// The name of the variable.
fn name(&self) -> Bound<'py, PyString>;

/// Return a value for the context variable for the current context.
fn get(&self) -> PyResult<Option<Bound<'py, PyAny>>>;

/// Return a value for the context variable for the current context.
fn get_or_default(&self, default: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>>;

/// Call to set a new value for the context variable in the current context.
///
/// Returns a Token object that can be used to restore the variable to its previous value via the ContextVar.reset() method.
fn set<T>(&self, value: Bound<'py, T>) -> PyResult<Bound<'py, PyContextToken>>;

/// Reset the context variable to the value it had before the ContextVar.set() that created the token was used.
fn reset(&self, token: Bound<'py, PyContextToken>) -> PyResult<()>;
}

/// Implementation of functionality for [`PyContextToken`].
///
/// These methods are defined for the `Bound<'py, PyContextToken>` smart pointer, so to use method call
/// syntax these methods are separated into a trait, because stable Rust does not yet support
/// `arbitrary_self_types`.
#[doc(alias = "PyContextToken")]
pub trait PyContextTokenMethods<'py>: crate::sealed::Sealed {
/// The ContextVar object that created this token
fn var(&self) -> PyResult<Bound<'py, PyContextVar>>;

/// Set to the value the variable had before the ContextVar.set() method call that created the token.
///
/// It returns `None`` if the variable was not set before the call.
fn old_value(&self) -> PyResult<Option<Bound<'py, PyAny>>>;
}

/// A mapping of ContextVars to their values.
///
/// Values of this type are accessed via PyO3's smart pointers, e.g. as
/// [`Py<PyContext>`][crate::Py] or [`Bound<'py, PyContext>`][Bound].
#[repr(transparent)]
pub struct PyContext(PyAny);
pyobject_native_type_core!(
PyContext,
pyobject_native_static_type_object!(ffi::PyContext_Type),
#module=Some("contextvars"),
#checkfunction=ffi::PyContext_CheckExact
);

impl PyContext {
/// Create a new empty context object
pub fn new(py: Python<'_>) -> PyResult<Bound<'_, PyContext>> {
unsafe {
ffi::PyContext_New()
.assume_owned_or_err(py)
.downcast_into_unchecked()
}
}

/// Returns a copy of the current Context object.
pub fn copy_current(py: Python<'_>) -> PyResult<Bound<'_, PyContext>> {
unsafe {
ffi::PyContext_CopyCurrent()
.assume_owned_or_err(py)
.downcast_into_unchecked()
}
}
}

impl<'py> PyContextMethods<'py> for Bound<'py, PyContext> {
fn copy(&self) -> PyResult<Bound<'py, PyContext>> {
unsafe {
ffi::PyContext_Copy(self.as_ptr())
.assume_owned_or_err(self.py())
.downcast_into_unchecked()
}
}

fn enter(&self) -> PyResult<()> {
let r = unsafe { ffi::PyContext_Enter(self.as_ptr()) };
if r == 0 {
Ok(())
} else {
Err(PyErr::fetch(self.py()))
}
Comment on lines +117 to +121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use error_on_minusone for this (same for the other instances).

}

fn exit(&self) -> PyResult<()> {
let r = unsafe { ffi::PyContext_Exit(self.as_ptr()) };
if r == 0 {
Ok(())
} else {
Err(PyErr::fetch(self.py()))
}
}
}

/// Bindings around `contextvars.ContextVar`.
///
/// Values of this type are accessed via PyO3's smart pointers, e.g. as
/// [`Py<PyContextVar>`][crate::Py] or [`Bound<'py, PyContextVar>`][Bound].
#[repr(transparent)]
pub struct PyContextVar(PyAny);
pyobject_native_type_core!(
PyContextVar,
pyobject_native_static_type_object!(ffi::PyContextVar_Type),
#module=Some("contextvars"),
#checkfunction=ffi::PyContextVar_CheckExact
);

impl PyContextVar {
/// Create new ContextVar with no default
pub fn new<'py>(py: Python<'py>, name: &'static CStr) -> PyResult<Bound<'py, PyContextVar>> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need the 'static lifetime:

Suggested change
pub fn new<'py>(py: Python<'py>, name: &'static CStr) -> PyResult<Bound<'py, PyContextVar>> {
pub fn new<'py>(py: Python<'py>, name: &CStr) -> PyResult<Bound<'py, PyContextVar>> {

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need the 'static lifetime:

That would be both reasonable and convenient, but I can't find the required lifetime on the cpython docs so I figured 'static was safe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string immediately gets converted into a python string so a 'static lifetime is not needed.

https://github.com/python/cpython/blob/bab1398a47f6d0cfc1be70497f306874c749ef7c/Python/context.c#L260-L270

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should not read the cpython source code and use it to make conclusions about what requirements need to be upheld or not. That's implementation details, after all.

However, if a 'static lifetime is necessary the docs will mention it. For example at PyMemberDef.

unsafe {
ffi::PyContextVar_New(name.as_ptr(), ptr::null_mut())
.assume_owned_or_err(py)
.downcast_into_unchecked()
}
}

/// Create new ContextVar with default value
pub fn with_default<'py, D: IntoPyObject<'py>>(py: Python<'py>, name: &CStr, default: D) -> PyResult<Bound<'py, PyContextVar>> {
let def = default.into_pyobject(py).map_err(Into::into)?;
unsafe {
ffi::PyContextVar_New(name.as_ptr(), def.as_ptr())
.assume_owned_or_err(py)
.downcast_into_unchecked()
}
}
}

impl<'py> PyContextVarMethods<'py> for Bound<'py, PyContextVar> {
fn name(&self) -> Bound<'py, PyString> {
self.getattr("name")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use https://docs.rs/pyo3/latest/pyo3/macro.intern.html for string literals like these.

.unwrap()
.downcast_into_exact::<PyString>()
.unwrap()
}

fn get(&self) -> PyResult<Option<Bound<'py, PyAny>>> {
let mut value = ptr::null_mut();
let r = unsafe { ffi::PyContextVar_Get(self.as_ptr(), ptr::null_mut(), &mut value) };
if r == 0 {
Ok(unsafe { value.assume_owned_or_opt(self.py()) })
} else {
Err(PyErr::fetch(self.py()))
}
}

fn get_or_default(&self, default: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
let mut value = ptr::null_mut();
let r = unsafe { ffi::PyContextVar_Get(self.as_ptr(), default.as_ptr(), &mut value) };
if r == 0 {
Ok(unsafe { value.assume_owned(self.py()) })
} else {
Err(PyErr::fetch(self.py()))
}
}

fn set<T>(&self, value: Bound<'py, T>) -> PyResult<Bound<'py, PyContextToken>> {
unsafe {
ffi::PyContextVar_Set(self.as_ptr(), value.as_ptr())
.assume_owned_or_err(self.py())
.downcast_into_unchecked()
}
}

fn reset(&self, token: Bound<'py, PyContextToken>) -> PyResult<()> {
let r = unsafe { ffi::PyContextVar_Reset(self.as_ptr(), token.as_ptr()) };
if r == 0 {
Ok(())
} else {
Err(PyErr::fetch(self.py()))
}
}
}


/// Bindings around `contextvars.Token`.
///
/// Values of this type are accessed via PyO3's smart pointers, e.g. as
#[repr(transparent)]
pub struct PyContextToken(PyAny);
pyobject_native_type_core!(
PyContextToken,
pyobject_native_static_type_object!(ffi::PyContextToken_Type),
#module=Some("contextvars"),
#checkfunction=ffi::PyContextToken_CheckExact
);

impl<'py> PyContextTokenMethods<'py> for Bound<'py, PyContextToken> {
fn var(&self) -> PyResult<Bound<'py, PyContextVar>> {
self.getattr("var")
.downcast_into()
}

fn old_value(&self) -> PyResult<Option<Bound<'py, PyAny>>> {
let old_value = self.getattr("old_value")?;

// Check if token is missing
static TOKEN_MISSING: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let missing = TOKEN_MISSING.get_or_init( self.py(), || {
PyContextToken::type_object(self.py())
.getattr("MISSING")
.expect("Unable to get contextvars.Token.MISSING")
.unbind()
});
Ok(if old_value.is(missing) {
None
} else {
Some(old_value)
})
}
}
5 changes: 5 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ pub use self::capsule::{PyCapsule, PyCapsuleMethods};
#[cfg(all(not(Py_LIMITED_API), not(PyPy), not(GraalPy)))]
pub use self::code::PyCode;
pub use self::complex::{PyComplex, PyComplexMethods};
pub use self::context::{
PyContext, PyContextVar, PyContextToken,
PyContextMethods, PyContextVarMethods, PyContextTokenMethods,
};
#[cfg(not(Py_LIMITED_API))]
#[allow(deprecated)]
pub use self::datetime::{
Expand Down Expand Up @@ -255,3 +259,4 @@ pub(crate) mod traceback;
pub(crate) mod tuple;
pub(crate) mod typeobject;
pub(crate) mod weakref;
pub(crate) mod context;
71 changes: 71 additions & 0 deletions tests/test_contextvars.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#![cfg(not(Py_LIMITED_API))]

use pyo3::exceptions::PyRuntimeError;
use pyo3::types::{PyContext, PyContextMethods, PyContextToken, PyContextTokenMethods, PyContextVar, PyContextVarMethods};
use pyo3::prelude::*;
use pyo3_ffi::c_str;

#[test]
fn test_context() {
Python::with_gil(|py| {
let context = PyContext::new(py).unwrap();
assert!(context.is_instance_of::<PyContext>());
assert!(context.is_exact_instance_of::<PyContext>());
assert!(!context.is_instance_of::<PyContextVar>());
assert!(!context.is_exact_instance_of::<PyContextToken>());

// Copy
let context2 = context.copy().unwrap();
assert!(context2.is_exact_instance_of::<PyContext>());
assert!(!context.is(&context2));
});
}

#[test]
fn test_context_copycurrent() {
Python::with_gil(|py| {
let current_context = PyContext::copy_current(py).unwrap();
assert!(current_context.is_exact_instance_of::<PyContext>());

let current_context2 = PyContext::copy_current(py).unwrap();
assert!(!current_context.is(&current_context2));
});
}

#[test]
fn test_contextvar_new() {
Python::with_gil(|py| {
let cv = PyContextVar::new(py, c_str!("test")).unwrap();
assert!(cv.is_exact_instance_of::<PyContextVar>());

assert!(cv.get().unwrap().is_none());
});
}


#[test]
fn test_contextvar_set() {
Python::with_gil(|py| {
let cv = PyContextVar::new(py, c_str!("test")).unwrap();
assert!(cv.is_exact_instance_of::<PyContextVar>());

assert!(cv.get().unwrap().is_none());

let token = cv.set(1_u64.into_pyobject(py).unwrap()).unwrap();
assert!(token.is_exact_instance_of::<PyContextToken>());
assert!(token.old_value().unwrap().is_none());
assert!(token.var().unwrap().is(&cv));
assert_eq!(cv.get().unwrap().unwrap().extract::<u64>().unwrap(), 1);

// Reset to default state
cv.reset(token.clone()).unwrap();
assert!(cv.get().unwrap().is_none());

// Check that we can't reset twice
{
let reset_err = cv.reset(token).unwrap_err();
assert!(reset_err.is_instance_of::<PyRuntimeError>(py));
assert!(reset_err.to_string().ends_with(" has already been used once"));
}
});
}
Loading