Skip to content

Commit

Permalink
Auto merge of #562 - Urgau:new-get-many-mut, r=Amanieu
Browse files Browse the repository at this point in the history
Change signature of `get_many_mut` APIs

This PR changes the signature and contract of the `get_many_mut` APIs by

1. panicking on overlapping keys
2. returning an array of Option rather than an Option of array.

This was asked by T-libs-api in rust-lang/rust#97601 (comment) regarding the corresponding std `HashMap` functions.
  • Loading branch information
bors committed Oct 1, 2024
2 parents 7cf51ea + d50e3b2 commit edd22e1
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 90 deletions.
154 changes: 103 additions & 51 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1467,8 +1467,11 @@ where
/// Attempts to get mutable references to `N` values in the map at once.
///
/// Returns an array of length `N` with the results of each query. For soundness, at most one
/// mutable reference will be returned to any value. `None` will be returned if any of the
/// keys are duplicates or missing.
/// mutable reference will be returned to any value. `None` will be used if the key is missing.
///
/// # Panics
///
/// Panics if any keys are overlapping.
///
/// # Examples
///
Expand All @@ -1481,33 +1484,52 @@ where
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
/// libraries.insert("Library of Congress".to_string(), 1800);
///
/// // Get Athenæum and Bodleian Library
/// let [Some(a), Some(b)] = libraries.get_many_mut([
/// "Athenæum",
/// "Bodleian Library",
/// ]) else { panic!() };
///
/// // Assert values of Athenæum and Library of Congress
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Library of Congress",
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// &mut 1807,
/// &mut 1800,
/// ]),
/// [
/// Some(&mut 1807),
/// Some(&mut 1800),
/// ],
/// );
///
/// // Missing keys result in None
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "New York Public Library",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(
/// got,
/// [
/// Some(&mut 1807),
/// None
/// ]
/// );
/// ```
///
/// ```should_panic
/// use hashbrown::HashMap;
///
/// // Duplicate keys result in None
/// let mut libraries = HashMap::new();
/// libraries.insert("Athenæum".to_string(), 1807);
///
/// // Duplicate keys panic!
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Athenæum",
/// ]);
/// assert_eq!(got, None);
/// ```
pub fn get_many_mut<Q, const N: usize>(&mut self, ks: [&Q; N]) -> Option<[&'_ mut V; N]>
pub fn get_many_mut<Q, const N: usize>(&mut self, ks: [&Q; N]) -> [Option<&'_ mut V>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1517,8 +1539,8 @@ where
/// Attempts to get mutable references to `N` values in the map at once, without validating that
/// the values are unique.
///
/// Returns an array of length `N` with the results of each query. `None` will be returned if
/// any of the keys are missing.
/// Returns an array of length `N` with the results of each query. `None` will be used if
/// the key is missing.
///
/// For a safe alternative see [`get_many_mut`](`HashMap::get_many_mut`).
///
Expand All @@ -1540,29 +1562,37 @@ where
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
/// libraries.insert("Library of Congress".to_string(), 1800);
///
/// let got = libraries.get_many_mut([
/// // SAFETY: The keys do not overlap.
/// let [Some(a), Some(b)] = (unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "Bodleian Library",
/// ]) }) else { panic!() };
///
/// // SAFETY: The keys do not overlap.
/// let got = unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "Library of Congress",
/// ]);
/// ]) };
/// assert_eq!(
/// got,
/// Some([
/// &mut 1807,
/// &mut 1800,
/// ]),
/// [
/// Some(&mut 1807),
/// Some(&mut 1800),
/// ],
/// );
///
/// // Missing keys result in None
/// let got = libraries.get_many_mut([
/// // SAFETY: The keys do not overlap.
/// let got = unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "New York Public Library",
/// ]);
/// assert_eq!(got, None);
/// ]) };
/// // Missing keys result in None
/// assert_eq!(got, [Some(&mut 1807), None]);
/// ```
pub unsafe fn get_many_unchecked_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[&'_ mut V; N]>
) -> [Option<&'_ mut V>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1574,8 +1604,11 @@ where
/// references to the corresponding keys.
///
/// Returns an array of length `N` with the results of each query. For soundness, at most one
/// mutable reference will be returned to any value. `None` will be returned if any of the keys
/// are duplicates or missing.
/// mutable reference will be returned to any value. `None` will be used if the key is missing.
///
/// # Panics
///
/// Panics if any keys are overlapping.
///
/// # Examples
///
Expand All @@ -1594,30 +1627,37 @@ where
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// (&"Bodleian Library".to_string(), &mut 1602),
/// (&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691),
/// ]),
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// Some((&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691)),
/// ],
/// );
/// // Missing keys result in None
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Gewandhaus",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(got, [Some((&"Bodleian Library".to_string(), &mut 1602)), None]);
/// ```
///
/// ```should_panic
/// use hashbrown::HashMap;
///
/// let mut libraries = HashMap::new();
/// libraries.insert("Bodleian Library".to_string(), 1602);
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
///
/// // Duplicate keys result in None
/// // Duplicate keys result in panic!
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Herzogin-Anna-Amalia-Bibliothek",
/// "Herzogin-Anna-Amalia-Bibliothek",
/// ]);
/// assert_eq!(got, None);
/// ```
pub fn get_many_key_value_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[(&'_ K, &'_ mut V); N]>
) -> [Option<(&'_ K, &'_ mut V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand Down Expand Up @@ -1657,30 +1697,36 @@ where
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// (&"Bodleian Library".to_string(), &mut 1602),
/// (&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691),
/// ]),
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// Some((&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691)),
/// ],
/// );
/// // Missing keys result in None
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Gewandhaus",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(
/// got,
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// None,
/// ],
/// );
/// ```
pub unsafe fn get_many_key_value_unchecked_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[(&'_ K, &'_ mut V); N]>
) -> [Option<(&'_ K, &'_ mut V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
self.get_many_unchecked_mut_inner(ks)
.map(|res| res.map(|(k, v)| (&*k, v)))
}

fn get_many_mut_inner<Q, const N: usize>(&mut self, ks: [&Q; N]) -> Option<[&'_ mut (K, V); N]>
fn get_many_mut_inner<Q, const N: usize>(&mut self, ks: [&Q; N]) -> [Option<&'_ mut (K, V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1692,7 +1738,7 @@ where
unsafe fn get_many_unchecked_mut_inner<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[&'_ mut (K, V); N]>
) -> [Option<&'_ mut (K, V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand Down Expand Up @@ -5937,33 +5983,39 @@ mod test_map {
}

#[test]
fn test_get_each_mut() {
fn test_get_many_mut() {
let mut map = HashMap::new();
map.insert("foo".to_owned(), 0);
map.insert("bar".to_owned(), 10);
map.insert("baz".to_owned(), 20);
map.insert("qux".to_owned(), 30);

let xs = map.get_many_mut(["foo", "qux"]);
assert_eq!(xs, Some([&mut 0, &mut 30]));
assert_eq!(xs, [Some(&mut 0), Some(&mut 30)]);

let xs = map.get_many_mut(["foo", "dud"]);
assert_eq!(xs, None);

let xs = map.get_many_mut(["foo", "foo"]);
assert_eq!(xs, None);
assert_eq!(xs, [Some(&mut 0), None]);

let ys = map.get_many_key_value_mut(["bar", "baz"]);
assert_eq!(
ys,
Some([(&"bar".to_owned(), &mut 10), (&"baz".to_owned(), &mut 20),]),
[
Some((&"bar".to_owned(), &mut 10)),
Some((&"baz".to_owned(), &mut 20))
],
);

let ys = map.get_many_key_value_mut(["bar", "dip"]);
assert_eq!(ys, None);
assert_eq!(ys, [Some((&"bar".to_string(), &mut 10)), None]);
}

#[test]
#[should_panic = "duplicate keys found"]
fn test_get_many_mut_duplicate() {
let mut map = HashMap::new();
map.insert("foo".to_owned(), 0);

let ys = map.get_many_key_value_mut(["baz", "baz"]);
assert_eq!(ys, None);
let _xs = map.get_many_mut(["foo", "foo"]);
}

#[test]
Expand Down
45 changes: 22 additions & 23 deletions src/raw/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::alloc::alloc::{handle_alloc_error, Layout};
use crate::scopeguard::{guard, ScopeGuard};
use crate::TryReserveError;
use core::array;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem;
use core::mem::MaybeUninit;
use core::ptr::NonNull;
use core::{hint, ptr};

Expand Down Expand Up @@ -484,6 +484,13 @@ impl<T> Bucket<T> {
}
}

/// Acquires the underlying non-null pointer `*mut T` to `data`.
#[inline]
fn as_non_null(&self) -> NonNull<T> {
// SAFETY: `self.ptr` is already a `NonNull`
unsafe { NonNull::new_unchecked(self.as_ptr()) }
}

/// Create a new [`Bucket`] that is offset from the `self` by the given
/// `offset`. The pointer calculation is performed by calculating the
/// offset from `self` pointer (convenience for `self.ptr.as_ptr().sub(offset)`).
Expand Down Expand Up @@ -1291,48 +1298,40 @@ impl<T, A: Allocator> RawTable<T, A> {
&mut self,
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
) -> [Option<&'_ mut T>; N] {
unsafe {
let ptrs = self.get_many_mut_pointers(hashes, eq)?;
let ptrs = self.get_many_mut_pointers(hashes, eq);

for (i, &cur) in ptrs.iter().enumerate() {
if ptrs[..i].iter().any(|&prev| ptr::eq::<T>(prev, cur)) {
return None;
for (i, cur) in ptrs.iter().enumerate() {
if cur.is_some() && ptrs[..i].contains(cur) {
panic!("duplicate keys found");
}
}
// All bucket are distinct from all previous buckets so we're clear to return the result
// of the lookup.

// TODO use `MaybeUninit::array_assume_init` here instead once that's stable.
Some(mem::transmute_copy(&ptrs))
ptrs.map(|ptr| ptr.map(|mut ptr| ptr.as_mut()))
}
}

pub unsafe fn get_many_unchecked_mut<const N: usize>(
&mut self,
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
let ptrs = self.get_many_mut_pointers(hashes, eq)?;
Some(mem::transmute_copy(&ptrs))
) -> [Option<&'_ mut T>; N] {
let ptrs = self.get_many_mut_pointers(hashes, eq);
ptrs.map(|ptr| ptr.map(|mut ptr| ptr.as_mut()))
}

unsafe fn get_many_mut_pointers<const N: usize>(
&mut self,
hashes: [u64; N],
mut eq: impl FnMut(usize, &T) -> bool,
) -> Option<[*mut T; N]> {
// TODO use `MaybeUninit::uninit_array` here instead once that's stable.
let mut outs: MaybeUninit<[*mut T; N]> = MaybeUninit::uninit();
let outs_ptr = outs.as_mut_ptr();

for (i, &hash) in hashes.iter().enumerate() {
let cur = self.find(hash, |k| eq(i, k))?;
*(*outs_ptr).get_unchecked_mut(i) = cur.as_mut();
}

// TODO use `MaybeUninit::array_assume_init` here instead once that's stable.
Some(outs.assume_init())
) -> [Option<NonNull<T>>; N] {
array::from_fn(|i| {
self.find(hashes[i], |k| eq(i, k))
.map(|cur| cur.as_non_null())
})
}

/// Returns the number of elements the map can hold without reallocating.
Expand Down
Loading

0 comments on commit edd22e1

Please sign in to comment.