diff --git a/umbral-pre-python/Cargo.toml b/umbral-pre-python/Cargo.toml index 8b9f21ff..3b330329 100644 --- a/umbral-pre-python/Cargo.toml +++ b/umbral-pre-python/Cargo.toml @@ -11,3 +11,4 @@ crate-type = ["cdylib"] pyo3 = { version = "0.13", features = ["extension-module"] } umbral-pre = { path = "../umbral-pre" } generic-array = "0.14" +hex = "0.4" diff --git a/umbral-pre-python/src/lib.rs b/umbral-pre-python/src/lib.rs index 990cdfef..2c463371 100644 --- a/umbral-pre-python/src/lib.rs +++ b/umbral-pre-python/src/lib.rs @@ -3,7 +3,7 @@ use pyo3::create_exception; use pyo3::exceptions::{PyException, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::PyClass; -use pyo3::types::PyBytes; +use pyo3::types::{PyBytes, PyUnicode}; use pyo3::wrap_pyfunction; use pyo3::PyObjectProtocol; @@ -42,6 +42,27 @@ fn from_bytes, U: SerializableToArray>(bytes: &[u8] }) } +fn hash, U: SerializableToArray>(obj: &T) -> PyResult { + let serialized = obj.as_backend().to_array(); + + // call `hash((class_name, bytes(obj)))` + Python::with_gil(|py| { + let builtins = PyModule::import(py, "builtins")?; + let arg1 = PyUnicode::new(py, T::name()); + let arg2: PyObject = PyBytes::new(py, serialized.as_slice()).into(); + builtins.getattr("hash")?.call1(((arg1, arg2),))?.extract() + }) +} + +// For some reason this lint is not recognized in Rust 1.46 (the one in CI) +// remove when CI is updated to a newer Rust version. +#[allow(clippy::unknown_clippy_lints)] +#[allow(clippy::unnecessary_wraps)] // Don't want to wrap it in Ok() on every call +fn hexstr, U: SerializableToArray>(obj: &T) -> PyResult { + let hex_str = hex::encode(obj.as_backend().to_array().as_slice()); + Ok(format!("{}:{}", T::name(), &hex_str[0..16])) +} + fn richcmp + PyClass + PartialEq, U>( obj: &T, other: PyRef, @@ -103,6 +124,10 @@ impl PyObjectProtocol for SecretKey { fn __bytes__(&self) -> PyResult { to_bytes(self) } + + fn __str__(&self) -> PyResult { + Ok(format!("{}:...", Self::name())) + } } #[pyclass(module = "umbral")] @@ -163,6 +188,10 @@ impl PyObjectProtocol for SecretKeyFactory { fn __bytes__(&self) -> PyResult { to_bytes(self) } + + fn __str__(&self) -> PyResult { + Ok(format!("{}:...", Self::name())) + } } #[pyclass(module = "umbral")] @@ -209,9 +238,18 @@ impl PyObjectProtocol for PublicKey { fn __bytes__(&self) -> PyResult { to_bytes(self) } + + fn __hash__(&self) -> PyResult { + hash(self) + } + + fn __str__(&self) -> PyResult { + hexstr(self) + } } #[pyclass(module = "umbral")] +#[derive(PartialEq)] pub struct Capsule { backend: umbral_pre::Capsule, } @@ -240,9 +278,21 @@ impl Capsule { #[pyproto] impl PyObjectProtocol for Capsule { + fn __richcmp__(&self, other: PyRef, op: CompareOp) -> PyResult { + richcmp(self, other, op) + } + fn __bytes__(&self) -> PyResult { to_bytes(self) } + + fn __hash__(&self) -> PyResult { + hash(self) + } + + fn __str__(&self) -> PyResult { + hexstr(self) + } } #[pyfunction] @@ -289,6 +339,7 @@ pub fn decrypt_original( } #[pyclass(module = "umbral")] +#[derive(PartialEq)] pub struct KeyFrag { backend: umbral_pre::KeyFrag, } @@ -330,9 +381,21 @@ impl KeyFrag { #[pyproto] impl PyObjectProtocol for KeyFrag { + fn __richcmp__(&self, other: PyRef, op: CompareOp) -> PyResult { + richcmp(self, other, op) + } + fn __bytes__(&self) -> PyResult { to_bytes(self) } + + fn __hash__(&self) -> PyResult { + hash(self) + } + + fn __str__(&self) -> PyResult { + hexstr(self) + } } #[allow(clippy::too_many_arguments)] @@ -364,7 +427,7 @@ pub fn generate_kfrags( } #[pyclass(module = "umbral")] -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub struct CapsuleFrag { backend: umbral_pre::CapsuleFrag, } @@ -410,9 +473,21 @@ impl CapsuleFrag { #[pyproto] impl PyObjectProtocol for CapsuleFrag { + fn __richcmp__(&self, other: PyRef, op: CompareOp) -> PyResult { + richcmp(self, other, op) + } + fn __bytes__(&self) -> PyResult { to_bytes(self) } + + fn __hash__(&self) -> PyResult { + hash(self) + } + + fn __str__(&self) -> PyResult { + hexstr(self) + } } #[pyfunction]