Skip to content

Commit

Permalink
Rollup merge of rust-lang#32699 - bluss:slice-memcmp, r=alexcrichton
Browse files Browse the repository at this point in the history
Specialize equality for [T] and comparison for [u8] to use memcmp when possible

Specialize equality for [T] and comparison for [u8] to use memcmp when possible

Where T is a type that can be compared for equality bytewise, we can use
memcmp. We can also use memcmp for PartialOrd, Ord for [u8].

Use specialization to call memcmp in PartialEq for slices for certain element types. This PR does not change the user visible API since the implementation uses an intermediate trait. See commit messages for more information.

The memcmp signature was changed from `*const i8` to `*const u8` which is in line with how the memcmp function is defined in C (taking const void * arguments, interpreting the values as unsigned bytes for purposes of the comparison).
  • Loading branch information
Manishearth committed Apr 7, 2016
2 parents a243f41 + a6c27be commit 925822a
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 44 deletions.
42 changes: 36 additions & 6 deletions src/libcollectionstest/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,18 +574,48 @@ fn test_slice_2() {
assert_eq!(v[1], 3);
}

macro_rules! assert_order {
(Greater, $a:expr, $b:expr) => {
assert_eq!($a.cmp($b), Greater);
assert!($a > $b);
};
(Less, $a:expr, $b:expr) => {
assert_eq!($a.cmp($b), Less);
assert!($a < $b);
};
(Equal, $a:expr, $b:expr) => {
assert_eq!($a.cmp($b), Equal);
assert_eq!($a, $b);
}
}

#[test]
fn test_total_ord_u8() {
let c = &[1u8, 2, 3];
assert_order!(Greater, &[1u8, 2, 3, 4][..], &c[..]);
let c = &[1u8, 2, 3, 4];
assert_order!(Less, &[1u8, 2, 3][..], &c[..]);
let c = &[1u8, 2, 3, 6];
assert_order!(Equal, &[1u8, 2, 3, 6][..], &c[..]);
let c = &[1u8, 2, 3, 4, 5, 6];
assert_order!(Less, &[1u8, 2, 3, 4, 5, 5, 5, 5][..], &c[..]);
let c = &[1u8, 2, 3, 4];
assert_order!(Greater, &[2u8, 2][..], &c[..]);
}


#[test]
fn test_total_ord() {
fn test_total_ord_i32() {
let c = &[1, 2, 3];
[1, 2, 3, 4][..].cmp(c) == Greater;
assert_order!(Greater, &[1, 2, 3, 4][..], &c[..]);
let c = &[1, 2, 3, 4];
[1, 2, 3][..].cmp(c) == Less;
assert_order!(Less, &[1, 2, 3][..], &c[..]);
let c = &[1, 2, 3, 6];
[1, 2, 3, 4][..].cmp(c) == Equal;
assert_order!(Equal, &[1, 2, 3, 6][..], &c[..]);
let c = &[1, 2, 3, 4, 5, 6];
[1, 2, 3, 4, 5, 5, 5, 5][..].cmp(c) == Less;
assert_order!(Less, &[1, 2, 3, 4, 5, 5, 5, 5][..], &c[..]);
let c = &[1, 2, 3, 4];
[2, 2][..].cmp(c) == Greater;
assert_order!(Greater, &[2, 2][..], &c[..]);
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions src/libcore/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
#![feature(unwind_attributes)]
#![feature(repr_simd, platform_intrinsics)]
#![feature(rustc_attrs)]
#![feature(specialization)]
#![feature(staged_api)]
#![feature(unboxed_closures)]
#![feature(question_mark)]
Expand Down
155 changes: 139 additions & 16 deletions src/libcore/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1630,12 +1630,60 @@ pub unsafe fn from_raw_parts_mut<'a, T>(p: *mut T, len: usize) -> &'a mut [T] {
}

//
// Boilerplate traits
// Comparison traits
//

extern {
/// Call implementation provided memcmp
///
/// Interprets the data as u8.
///
/// Return 0 for equal, < 0 for less than and > 0 for greater
/// than.
// FIXME(#32610): Return type should be c_int
fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32;
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
fn eq(&self, other: &[B]) -> bool {
SlicePartialEq::equal(self, other)
}

fn ne(&self, other: &[B]) -> bool {
SlicePartialEq::not_equal(self, other)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Eq> Eq for [T] {}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Ord> Ord for [T] {
fn cmp(&self, other: &[T]) -> Ordering {
SliceOrd::compare(self, other)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: PartialOrd> PartialOrd for [T] {
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
SlicePartialOrd::partial_compare(self, other)
}
}

#[doc(hidden)]
// intermediate trait for specialization of slice's PartialEq
trait SlicePartialEq<B> {
fn equal(&self, other: &[B]) -> bool;
fn not_equal(&self, other: &[B]) -> bool;
}

// Generic slice equality
impl<A, B> SlicePartialEq<B> for [A]
where A: PartialEq<B>
{
default fn equal(&self, other: &[B]) -> bool {
if self.len() != other.len() {
return false;
}
Expand All @@ -1648,7 +1696,8 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {

true
}
fn ne(&self, other: &[B]) -> bool {

default fn not_equal(&self, other: &[B]) -> bool {
if self.len() != other.len() {
return true;
}
Expand All @@ -1663,12 +1712,36 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Eq> Eq for [T] {}
// Use memcmp for bytewise equality when the types allow
impl<A> SlicePartialEq<A> for [A]
where A: PartialEq<A> + BytewiseEquality
{
fn equal(&self, other: &[A]) -> bool {
if self.len() != other.len() {
return false;
}
unsafe {
let size = mem::size_of_val(self);
memcmp(self.as_ptr() as *const u8,
other.as_ptr() as *const u8, size) == 0
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Ord> Ord for [T] {
fn cmp(&self, other: &[T]) -> Ordering {
fn not_equal(&self, other: &[A]) -> bool {
!self.equal(other)
}
}

#[doc(hidden)]
// intermediate trait for specialization of slice's PartialOrd
trait SlicePartialOrd<B> {
fn partial_compare(&self, other: &[B]) -> Option<Ordering>;
}

impl<A> SlicePartialOrd<A> for [A]
where A: PartialOrd
{
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
let l = cmp::min(self.len(), other.len());

// Slice to the loop iteration range to enable bound check
Expand All @@ -1677,19 +1750,33 @@ impl<T: Ord> Ord for [T] {
let rhs = &other[..l];

for i in 0..l {
match lhs[i].cmp(&rhs[i]) {
Ordering::Equal => (),
match lhs[i].partial_cmp(&rhs[i]) {
Some(Ordering::Equal) => (),
non_eq => return non_eq,
}
}

self.len().cmp(&other.len())
self.len().partial_cmp(&other.len())
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: PartialOrd> PartialOrd for [T] {
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
impl SlicePartialOrd<u8> for [u8] {
#[inline]
fn partial_compare(&self, other: &[u8]) -> Option<Ordering> {
Some(SliceOrd::compare(self, other))
}
}

#[doc(hidden)]
// intermediate trait for specialization of slice's Ord
trait SliceOrd<B> {
fn compare(&self, other: &[B]) -> Ordering;
}

impl<A> SliceOrd<A> for [A]
where A: Ord
{
default fn compare(&self, other: &[A]) -> Ordering {
let l = cmp::min(self.len(), other.len());

// Slice to the loop iteration range to enable bound check
Expand All @@ -1698,12 +1785,48 @@ impl<T: PartialOrd> PartialOrd for [T] {
let rhs = &other[..l];

for i in 0..l {
match lhs[i].partial_cmp(&rhs[i]) {
Some(Ordering::Equal) => (),
match lhs[i].cmp(&rhs[i]) {
Ordering::Equal => (),
non_eq => return non_eq,
}
}

self.len().partial_cmp(&other.len())
self.len().cmp(&other.len())
}
}

// memcmp compares a sequence of unsigned bytes lexicographically.
// this matches the order we want for [u8], but no others (not even [i8]).
impl SliceOrd<u8> for [u8] {
#[inline]
fn compare(&self, other: &[u8]) -> Ordering {
let order = unsafe {
memcmp(self.as_ptr(), other.as_ptr(),
cmp::min(self.len(), other.len()))
};
if order == 0 {
self.len().cmp(&other.len())
} else if order < 0 {
Less
} else {
Greater
}
}
}

#[doc(hidden)]
/// Trait implemented for types that can be compared for equality using
/// their bytewise representation
trait BytewiseEquality { }

macro_rules! impl_marker_for {
($traitname:ident, $($ty:ty)*) => {
$(
impl $traitname for $ty { }
)*
}
}

impl_marker_for!(BytewiseEquality,
u8 i8 u16 i16 u32 i32 u64 i64 usize isize char bool);

25 changes: 3 additions & 22 deletions src/libcore/str/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1150,16 +1150,7 @@ Section: Comparing strings
#[lang = "str_eq"]
#[inline]
fn eq_slice(a: &str, b: &str) -> bool {
a.len() == b.len() && unsafe { cmp_slice(a, b, a.len()) == 0 }
}

/// Bytewise slice comparison.
/// NOTE: This uses the system's memcmp, which is currently dramatically
/// faster than comparing each byte in a loop.
#[inline]
unsafe fn cmp_slice(a: &str, b: &str, len: usize) -> i32 {
extern { fn memcmp(s1: *const i8, s2: *const i8, n: usize) -> i32; }
memcmp(a.as_ptr() as *const i8, b.as_ptr() as *const i8, len)
a.as_bytes() == b.as_bytes()
}

/*
Expand Down Expand Up @@ -1328,8 +1319,7 @@ Section: Trait implementations
*/

mod traits {
use cmp::{self, Ordering, Ord, PartialEq, PartialOrd, Eq};
use cmp::Ordering::{Less, Greater};
use cmp::{Ord, Ordering, PartialEq, PartialOrd, Eq};
use iter::Iterator;
use option::Option;
use option::Option::Some;
Expand All @@ -1340,16 +1330,7 @@ mod traits {
impl Ord for str {
#[inline]
fn cmp(&self, other: &str) -> Ordering {
let cmp = unsafe {
super::cmp_slice(self, other, cmp::min(self.len(), other.len()))
};
if cmp == 0 {
self.len().cmp(&other.len())
} else if cmp < 0 {
Less
} else {
Greater
}
self.as_bytes().cmp(other.as_bytes())
}
}

Expand Down

0 comments on commit 925822a

Please sign in to comment.