Skip to content
This repository was archived by the owner on Nov 9, 2025. It is now read-only.

Commit 3122a9a

Browse files
committed
Optimise to_{lower,upper}case_smolstr
1 parent 81c8790 commit 3122a9a

File tree

1 file changed

+98
-6
lines changed

1 file changed

+98
-6
lines changed

src/lib.rs

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,17 @@ impl iter::FromIterator<char> for SmolStr {
233233
}
234234
}
235235

236-
fn from_char_iter(mut iter: impl Iterator<Item = char>) -> SmolStr {
237-
let (min_size, _) = iter.size_hint();
236+
#[inline]
237+
fn from_char_iter(iter: impl Iterator<Item = char>) -> SmolStr {
238+
from_buf_and_chars([0; _], 0, iter)
239+
}
240+
241+
fn from_buf_and_chars(
242+
mut buf: [u8; INLINE_CAP],
243+
buf_len: usize,
244+
mut iter: impl Iterator<Item = char>,
245+
) -> SmolStr {
246+
let min_size = iter.size_hint().0 + buf_len;
238247
if min_size > INLINE_CAP {
239248
let heap: String = iter.collect();
240249
if heap.len() <= INLINE_CAP {
@@ -243,8 +252,7 @@ fn from_char_iter(mut iter: impl Iterator<Item = char>) -> SmolStr {
243252
}
244253
return SmolStr(Repr::Heap(heap.into_boxed_str().into()));
245254
}
246-
let mut len = 0;
247-
let mut buf = [0u8; INLINE_CAP];
255+
let mut len = buf_len;
248256
while let Some(ch) = iter.next() {
249257
let size = ch.len_utf8();
250258
if size + len > INLINE_CAP {
@@ -634,12 +642,32 @@ pub trait StrExt: private::Sealed {
634642
impl StrExt for str {
635643
#[inline]
636644
fn to_lowercase_smolstr(&self) -> SmolStr {
637-
from_char_iter(self.chars().flat_map(|c| c.to_lowercase()))
645+
let len = self.len();
646+
if len <= INLINE_CAP {
647+
let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_lowercase);
648+
from_buf_and_chars(
649+
buf,
650+
len - rest.len(),
651+
rest.chars().flat_map(|c| c.to_lowercase()),
652+
)
653+
} else {
654+
self.to_lowercase().into()
655+
}
638656
}
639657

640658
#[inline]
641659
fn to_uppercase_smolstr(&self) -> SmolStr {
642-
from_char_iter(self.chars().flat_map(|c| c.to_uppercase()))
660+
let len = self.len();
661+
if len <= INLINE_CAP {
662+
let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_uppercase);
663+
from_buf_and_chars(
664+
buf,
665+
len - rest.len(),
666+
rest.chars().flat_map(|c| c.to_uppercase()),
667+
)
668+
} else {
669+
self.to_uppercase().into()
670+
}
643671
}
644672

645673
#[inline]
@@ -699,6 +727,70 @@ impl StrExt for str {
699727
}
700728
}
701729

730+
/// Inline version of std fn `convert_while_ascii`. `s` must have len <= 23.
731+
#[inline]
732+
fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_CAP], &str) {
733+
// Process the input in chunks of 16 bytes to enable auto-vectorization.
734+
// Previously the chunk size depended on the size of `usize`,
735+
// but on 32-bit platforms with sse or neon is also the better choice.
736+
// The only downside on other platforms would be a bit more loop-unrolling.
737+
const N: usize = 16;
738+
739+
debug_assert!(s.len() <= INLINE_CAP, "only for inline-able strings");
740+
741+
let mut slice = s.as_bytes();
742+
let mut out = [0u8; INLINE_CAP];
743+
let mut out_slice = &mut out[..slice.len()];
744+
let mut is_ascii = [false; N];
745+
746+
while slice.len() >= N {
747+
// SAFETY: checked in loop condition
748+
let chunk = unsafe { slice.get_unchecked(..N) };
749+
// SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets
750+
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
751+
752+
for j in 0..N {
753+
is_ascii[j] = chunk[j] <= 127;
754+
}
755+
756+
// Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk
757+
// size gives the best result, specifically a pmovmsk instruction on x86.
758+
// See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not
759+
// currently recognize other similar idioms.
760+
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
761+
break;
762+
}
763+
764+
for j in 0..N {
765+
out_chunk[j] = convert(&chunk[j]);
766+
}
767+
768+
slice = unsafe { slice.get_unchecked(N..) };
769+
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
770+
}
771+
772+
// handle the remainder as individual bytes
773+
while !slice.is_empty() {
774+
let byte = slice[0];
775+
if byte > 127 {
776+
break;
777+
}
778+
// SAFETY: out_slice has at least same length as input slice
779+
unsafe {
780+
*out_slice.get_unchecked_mut(0) = convert(&byte);
781+
}
782+
slice = unsafe { slice.get_unchecked(1..) };
783+
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
784+
}
785+
786+
unsafe {
787+
// SAFETY: we know this is a valid char boundary
788+
// since we only skipped over leading ascii bytes
789+
let rest = core::str::from_utf8_unchecked(slice);
790+
(out, rest)
791+
}
792+
}
793+
702794
impl<T> ToSmolStr for T
703795
where
704796
T: fmt::Display + ?Sized,

0 commit comments

Comments
 (0)