Skip to content

Commit 2e3e88c

Browse files
committed
[stdlib] Optimize Set.intersection(_:)
Use a temporary bitset to speed up the `Sequence` variant by a factor of roughly 4-6, and the set/set variant by a factor of roughly 1-4, depending on the ratio of overlapping elements.
1 parent db80c78 commit 2e3e88c

File tree

3 files changed

+87
-25
lines changed

3 files changed

+87
-25
lines changed

stdlib/public/core/NativeSet.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ extension _NativeSet {
674674
count -= 1
675675
if count == 0 { break }
676676
}
677+
_internalInvariant(result.count == bitset.count)
677678
return result
678679
}
679680

@@ -716,4 +717,49 @@ extension _NativeSet {
716717
return extractSubset(using: bitset, count: count)
717718
}
718719
}
720+
721+
/// Returns a new native set holding the members of `self` that are also
/// found in `other`.
///
/// The result is always built from elements stored in `self`; `other` is
/// only consulted for membership.
@_alwaysEmitIntoClient
internal __consuming func intersection(
  _ other: _NativeSet<Element>
) -> _NativeSet<Element> {
  // Iterating over the smaller side is cheaper. When `self` is the larger
  // one, walk `other` instead via the generic path, which still only emits
  // elements taken from `self`.
  if self.count > other.count {
    return genericIntersection(other)
  }
  // Record the buckets of shared elements in a bitset before building the
  // result. This keeps hashing to a minimum and yields an exact element
  // count up front, so the result never has to be rehashed while it is
  // being populated.
  return _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { bitset in
    var matchCount = 0
    for bucket in hashTable
    where other.find(uncheckedElement(at: bucket)).found {
      bitset.uncheckedInsert(bucket.offset)
      matchCount += 1
    }
    return extractSubset(using: bitset, count: matchCount)
  }
}
744+
745+
/// Returns a new native set holding the members of `self` that also occur
/// in the sequence `other`.
///
/// `other` may contain the same element more than once; each member of
/// `self` is included in the result at most once.
///
/// - Parameter other: A sequence of elements to intersect with `self`.
/// - Returns: The subset of `self` whose elements appear in `other`.
@_alwaysEmitIntoClient
internal __consuming func genericIntersection<S: Sequence>(
  _ other: S
) -> _NativeSet<Element>
where S.Element == Element {
  // Rather than directly creating a new set, mark common elements in a bitset
  // first. This minimizes hashing, and ensures that we'll have an exact count
  // for the result set, preventing rehashings during insertions.
  return _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { bitset in
    var count = 0
    for element in other {
      let (bucket, found) = find(element)
      // `other` is an arbitrary sequence and may repeat elements. Only count
      // a match the first time its bucket is marked; otherwise `count` would
      // exceed the number of bits set and `extractSubset` would be handed an
      // inexact element count.
      if found, bitset.uncheckedInsert(bucket.offset) {
        count += 1
      }
    }
    return extractSubset(using: bitset, count: count)
  }
}
719765
}

stdlib/public/core/Set.swift

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -949,8 +949,10 @@ extension Set: SetAlgebra {
949949
@inlinable
950950
public __consuming func intersection<S: Sequence>(_ other: S) -> Set<Element>
951951
where S.Element == Element {
952-
let otherSet = Set(other)
953-
return intersection(otherSet)
952+
if let other = other as? Set<Element> {
953+
return self.intersection(other)
954+
}
955+
return Set(_native: _variant.convertedToNative.genericIntersection(other))
954956
}
955957

956958
/// Removes the elements of the set that aren't also in the given sequence.
@@ -969,20 +971,10 @@ extension Set: SetAlgebra {
969971
@inlinable
970972
public mutating func formIntersection<S: Sequence>(_ other: S)
971973
where S.Element == Element {
972-
// Because `intersect` needs to both modify and iterate over
973-
// the left-hand side, the index may become invalidated during
974-
// traversal so an intermediate set must be created.
975-
//
976-
// FIXME(performance): perform this operation at a lower level
977-
// to avoid invalidating the index and avoiding a copy.
978-
let result = self.intersection(other)
979-
980-
// The result can only have fewer or the same number of elements.
981-
// If no elements were removed, don't perform a reassignment
982-
// as this may cause an unnecessary uniquing COW.
983-
if result.count != count {
984-
self = result
985-
}
974+
// FIXME: This discards storage reserved with reserveCapacity.
975+
// FIXME: Depending on the ratio of elements kept in the result, it may be
976+
// faster to do the removals in place, in bulk.
977+
self = self.intersection(other)
986978
}
987979

988980
/// Returns a new set with the elements that are either in this set or in the
@@ -1233,15 +1225,7 @@ extension Set {
12331225
/// - Returns: A new set.
12341226
@inlinable
12351227
public __consuming func intersection(_ other: Set<Element>) -> Set<Element> {
1236-
var newSet = Set<Element>()
1237-
let (smaller, larger) =
1238-
count < other.count ? (self, other) : (other, self)
1239-
for member in smaller {
1240-
if larger.contains(member) {
1241-
newSet.insert(member)
1242-
}
1243-
}
1244-
return newSet
1228+
Set(_native: _variant.intersection(other))
12451229
}
12461230

12471231
/// Removes the elements of the set that are also in the given sequence and

stdlib/public/core/SetVariant.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,4 +396,36 @@ extension Set._Variant {
396396
#endif
397397
return try asNative.filter(isIncluded)
398398
}
399+
400+
/// Returns a native set holding the members of this variant that are also
/// members of `other`, dispatching on whether each side is backed by native
/// or Cocoa storage.
@_alwaysEmitIntoClient
internal __consuming func intersection(
  _ other: Set<Element>
) -> _NativeSet<Element> {
#if _runtime(_ObjC)
  if self.isNative {
    // Native receiver: use the set/set fast path when `other` is native as
    // well, otherwise treat `other` as a plain sequence.
    guard other._variant.isNative else {
      return asNative.genericIntersection(other)
    }
    return asNative.intersection(other._variant.asNative)
  }

  // Cocoa receiver from here on.
  guard other._variant.isNative else {
    return _NativeSet(asCocoa).genericIntersection(other)
  }

  // Cocoa receiver, native argument. Swapping the operands and intersecting
  // the native side looks attractive, but the operation is not symmetric:
  // the result must be made up of elements taken from `self`. So bridge each
  // Cocoa element and keep the ones the native set contains.
  let nativeOther = other._variant.asNative
  var result = _NativeSet<Element>()
  for cocoaElement in asCocoa {
    let bridged = _forceBridgeFromObjectiveC(cocoaElement, Element.self)
    guard nativeOther.contains(bridged) else { continue }
    result.insertNew(bridged, isUnique: true)
  }
  return result
#else
  return asNative.intersection(other._variant.asNative)
#endif
}
399431
}

0 commit comments

Comments
 (0)