//! Attempt at flexible symbol interning, allowing strings to be interned and freed at runtime
//! while also supporting compile-time declared symbols that are never freed.
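//!
//! A minimal usage sketch; the `intern::symbol` import path is an assumption for illustration,
//! not something this file defines:
//!
//! ```ignore
//! use intern::symbol::Symbol;
//!
//! let a = Symbol::intern("hello"); // dynamically interned, backed by an `Arc<Box<str>>`
//! let b = Symbol::intern("hello"); // hits the same map entry, only bumping the ref count
//! assert_eq!(a, b);
//! assert_eq!(a.as_str(), "hello");
//! drop(a);
//! drop(b); // the last reference is gone, so the entry is removed from the global map
//! ```
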
use std::{
    borrow::Borrow,
    fmt,
    hash::{BuildHasherDefault, Hash, Hasher},
    mem,
    ptr::NonNull,
    sync::OnceLock,
};

use dashmap::{DashMap, SharedValue};
use hashbrown::{hash_map::RawEntryMut, HashMap};
use rustc_hash::FxHasher;
use sptr::Strict;
use triomphe::Arc;

pub mod symbols;

// Compile-time layout checks: the tagged-pointer scheme below relies on `Box<str>` having the
// same size and alignment as `&str`, and `Arc<Box<str>>` the same as `&&str`.
const _: () = assert!(std::mem::size_of::<Box<str>>() == std::mem::size_of::<&str>());
const _: () = assert!(std::mem::align_of::<Box<str>>() == std::mem::align_of::<&str>());

const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<&&str>());
const _: () = assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<&&str>());

/// A pointer that points to a pointer to a `str`. It may be backed by either a
/// `&'static &'static str` or an `Arc<Box<str>>`, but its size is always that of a thin pointer.
/// The active variant is encoded as a tag in the LSB of the alignment niche.
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
struct TaggedArcPtr {
    packed: NonNull<*const str>,
}

// SAFETY: The pointer is either a shared `'static` reference or points into an `Arc`, both of
// which are safe to send to and share between threads.
unsafe impl Send for TaggedArcPtr {}
unsafe impl Sync for TaggedArcPtr {}

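// The tag encoding, illustratively (the concrete bit patterns are an assumption based on
// pointer alignment being at least 2, which leaves the LSB free):
//
//   packed = 0b...xxx0  ->  `&'static &'static str`, address used as-is (tag 0)
//   packed = 0b...xxx1  ->  `Arc<Box<str>>` data pointer with the tag bit set (tag 1)
//
// `pointer()` masks the tag bit off again before the pointee is read.
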
impl TaggedArcPtr {
    const BOOL_BITS: usize = true as usize;

    const fn non_arc(r: &'static &'static str) -> Self {
        Self {
            // SAFETY: The pointer is non-null as it is derived from a reference.
            // Ideally we would call out to `pack_arc` with a `false` tag, but the packing
            // requires reading the pointer out as an integer, which is not supported in const
            // contexts. So we use the fact that for the non-arc variant the tag is `false` (0)
            // and therefore does not require touching the actual pointer value.
            packed: unsafe {
                NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut())
            },
        }
    }

    fn arc(arc: Arc<Box<str>>) -> Self {
        Self {
            packed: Self::pack_arc(
                // SAFETY: `Arc::into_raw` always returns a non-null pointer
                unsafe { NonNull::new_unchecked(Arc::into_raw(arc).cast_mut().cast()) },
            ),
        }
    }

    /// Unpacks the pointer and, if the `Arc` variant is active, reassembles the `Arc` from its
    /// raw data pointer, taking ownership of one reference count.
    #[inline]
    pub(crate) fn try_as_arc_owned(self) -> Option<Arc<Box<str>>> {
        // Unpack the tag from the alignment niche
        let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS;
        if tag != 0 {
            // SAFETY: We checked that the tag is non-zero -> true, so we are pointing to the
            // data offset of an `Arc`
            Some(unsafe { Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>()) })
        } else {
            None
        }
    }

    #[inline]
    const fn pack_arc(ptr: NonNull<*const str>) -> NonNull<*const str> {
        let packed_tag = true as usize;

        // We can't use the strict provenance helpers here, as trait methods are not const:
        // unsafe {
        //     // SAFETY: The pointer is derived from a non-null pointer
        //     NonNull::new_unchecked(Strict::map_addr(ptr.as_ptr(), |addr| {
        //         // SAFETY:
        //         // - The pointer is `NonNull` => its address is `NonZero<usize>`
        //         // - `P::BITS` least significant bits are always zero (`Pointer` contract)
        //         // - `T::BITS <= P::BITS` (from `Self::ASSERTION`)
        //         //
        //         // Thus `addr >> T::BITS` is guaranteed to be non-zero.
        //         //
        //         // `{non_zero} | packed_tag` can't make the value zero.

        //         (addr >> Self::BOOL_BITS) | packed_tag
        //     }))
        // }
        // So what follows is roughly the above, but inlined by hand:

        let self_addr = unsafe { core::mem::transmute::<*const _, usize>(ptr.as_ptr()) };
        let addr = self_addr | packed_tag;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr as isize);

        // SAFETY: The resulting pointer is guaranteed to be non-null as we only set the tag bit
        // in the alignment niche
        unsafe { NonNull::new_unchecked(ptr.as_ptr().cast::<u8>().wrapping_offset(offset).cast()) }
    }
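
    // Worked example of the packing above, with illustrative addresses: if `self_addr` is
    // 0x7f00, then `addr` is 0x7f01 and `offset` is 1, so the returned pointer has address
    // 0x7f01 while keeping the provenance of `ptr` (which is why the cast-and-offset dance is
    // used instead of a plain integer-to-pointer cast).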

    #[inline]
    pub(crate) fn pointer(self) -> NonNull<*const str> {
        // SAFETY: The resulting pointer is guaranteed to be non-null as we only clear the tag
        // bit in the alignment niche
        unsafe {
            NonNull::new_unchecked(Strict::map_addr(self.packed.as_ptr(), |addr| {
                addr & !Self::BOOL_BITS
            }))
        }
    }

    #[inline]
    pub(crate) fn as_str(&self) -> &str {
        // SAFETY: We always point to a pointer to a `str`, no matter which variant is active
        unsafe { *self.pointer().as_ptr().cast::<&str>() }
    }
}

// `Clone` is implemented manually below so that cloning bumps the backing `Arc`'s ref count;
// a derived `Clone` would copy the tagged pointer without bumping it, unbalancing `Drop`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Symbol {
    repr: TaggedArcPtr,
}
const _: () = assert!(std::mem::size_of::<Symbol>() == std::mem::size_of::<NonNull<()>>());
const _: () = assert!(std::mem::align_of::<Symbol>() == std::mem::align_of::<NonNull<()>>());

static MAP: OnceLock<DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>>> = OnceLock::new();

impl Symbol {
    pub fn intern(s: &str) -> Self {
        let (mut shard, hash) = Self::select_shard(s);
        // Atomically,
        // - check if `s` is already in the map,
        // - if so, copy out its entry, conditionally bumping the backing `Arc`, and return it,
        // - if not, box it, wrap it in an `Arc`, insert it, bump the ref count and return the copy.
        // This needs to be atomic (holding the shard lock) to avoid races with other threads,
        // which could otherwise insert the same string between our lookup and our insert.
        match shard.raw_entry_mut().from_key_hashed_nocheck(hash, s) {
            RawEntryMut::Occupied(occ) => Self { repr: increase_arc_refcount(occ.key().0) },
            RawEntryMut::Vacant(vac) => Self {
                repr: increase_arc_refcount(
                    vac.insert_hashed_nocheck(
                        hash,
                        SymbolProxy(TaggedArcPtr::arc(Arc::new(Box::<str>::from(s)))),
                        SharedValue::new(()),
                    )
                    .0
                    .0,
                ),
            },
        }
    }

    pub fn as_str(&self) -> &str {
        self.repr.as_str()
    }

    #[inline]
    fn select_shard(
        s: &str,
    ) -> (
        dashmap::RwLockWriteGuard<
            'static,
            HashMap<SymbolProxy, SharedValue<()>, BuildHasherDefault<FxHasher>>,
        >,
        u64,
    ) {
        let storage = MAP.get_or_init(symbols::prefill);
        let hash = {
            let mut hasher = std::hash::BuildHasher::build_hasher(storage.hasher());
            s.hash(&mut hasher);
            hasher.finish()
        };
        let shard_idx = storage.determine_shard(hash as usize);
        let shard = &storage.shards()[shard_idx];
        (shard.write(), hash)
    }

    #[cold]
    fn drop_slow(arc: &Arc<Box<str>>) {
        let (mut shard, hash) = Self::select_shard(arc);

        if Arc::count(arc) != 2 {
            // Another thread interned the string again between our check in `Drop` and us
            // taking the shard lock, so the entry has to stay.
            return;
        }

        match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) {
            RawEntryMut::Occupied(occ) => occ.remove_entry(),
            RawEntryMut::Vacant(_) => unreachable!(),
        }
        .0
        .0
        .try_as_arc_owned()
        .unwrap();

        // Shrink the backing storage if the shard is less than 50% occupied.
        if shard.len() * 2 < shard.capacity() {
            shard.shrink_to_fit();
        }
    }
}

impl Clone for Symbol {
    fn clone(&self) -> Self {
        // Bump the backing `Arc`'s ref count (if any) so that the decrement in `Drop` stays
        // balanced.
        Self { repr: increase_arc_refcount(self.repr) }
    }
}

impl Drop for Symbol {
    #[inline]
    fn drop(&mut self) {
        let Some(arc) = self.repr.try_as_arc_owned() else {
            return;
        };
        // When the last `Symbol` referencing this string is dropped, remove it from the global map.
        if Arc::count(&arc) == 2 {
            // Only `self` and the global map point to the object.
            Self::drop_slow(&arc);
        }
        // decrement the ref count
        drop(arc);
    }
}

fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr {
    let Some(arc) = repr.try_as_arc_owned() else {
        return repr;
    };
    // Increase the ref count: the clone bumps it, and forgetting both the clone and the `Arc`
    // we took ownership of keeps either from being decremented again.
    mem::forget(arc.clone());
    mem::forget(arc);
    repr
}

impl fmt::Display for Symbol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.as_str().fmt(f)
    }
}

// Only exists so that we can use `from_key_hashed_nocheck` with a `&str` key.
#[derive(Debug, PartialEq, Eq)]
struct SymbolProxy(TaggedArcPtr);

impl Hash for SymbolProxy {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.as_str().hash(state);
    }
}

impl Borrow<str> for SymbolProxy {
    fn borrow(&self) -> &str {
        self.0.as_str()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke_test() {
        Symbol::intern("isize");
        let base_len = MAP.get().unwrap().len();
        let hello = Symbol::intern("hello");
        let world = Symbol::intern("world");
        let bang = Symbol::intern("!");
        let q = Symbol::intern("?");
        assert_eq!(MAP.get().unwrap().len(), base_len + 4);
        let bang2 = Symbol::intern("!");
        assert_eq!(MAP.get().unwrap().len(), base_len + 4);
        drop(bang2);
        assert_eq!(MAP.get().unwrap().len(), base_len + 4);
        drop(q);
        assert_eq!(MAP.get().unwrap().len(), base_len + 3);
        let default = Symbol::intern("default");
        assert_eq!(MAP.get().unwrap().len(), base_len + 3);
        assert_eq!(
            "hello default world!",
            format!("{} {} {}{}", hello.as_str(), default.as_str(), world.as_str(), bang.as_str())
        );
        drop(default);
        assert_eq!(
            "hello world!",
            format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str())
        );
        drop(hello);
        drop(world);
        drop(bang);
        assert_eq!(MAP.get().unwrap().len(), base_len);
    }
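
    // The two tests below are illustrative additions, not part of the original suite. This one
    // sketches the tagged-pointer roundtrip: pack an `Arc`, read the string back through the
    // thin pointer, then unpack it again, for both variants.
    #[test]
    fn tagged_arc_ptr_roundtrip() {
        let ptr = TaggedArcPtr::arc(Arc::new(Box::<str>::from("roundtrip")));
        assert_eq!(ptr.as_str(), "roundtrip");
        // Unpack and drop the `Arc` to release the allocation.
        let arc = ptr.try_as_arc_owned().unwrap();
        assert_eq!(Arc::count(&arc), 1);
        drop(arc);

        // The static variant carries a zero tag and is never freed.
        let stat = TaggedArcPtr::non_arc(&"static");
        assert_eq!(stat.as_str(), "static");
        assert!(stat.try_as_arc_owned().is_none());
    }

    // Illustrative check that `Clone` bumps the backing `Arc`: dropping the original must leave
    // the clone (and hence the map entry) intact. Deliberately avoids asserting on the global
    // map length, since tests may run concurrently.
    #[test]
    fn clone_outlives_original() {
        let a = Symbol::intern("clone_outlives_original");
        let b = a.clone();
        assert_eq!(a, b);
        drop(a);
        // The clone still holds a reference, so the backing string stays valid.
        assert_eq!(b.as_str(), "clone_outlives_original");
        drop(b);
    }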
}