Skip to content

Light up String.Manipulation APIs with Vector512 codepath #92579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
241 changes: 182 additions & 59 deletions src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,21 @@ public string Replace(char oldChar, char newChar)
// process the remaining elements vectorized too.
// Thus we adjust the pointers so that at least one full vector from the end can be processed.
nuint length = (uint)Length;
if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector512<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
pDst = ref Unsafe.Subtract(ref pDst, adjust);
remainingLength += adjust;
}
else if (Vector256.IsHardwareAccelerated && length >= (uint)Vector256<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector256<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
pDst = ref Unsafe.Subtract(ref pDst, adjust);
remainingLength += adjust;
}
else if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector128<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
Expand Down Expand Up @@ -1224,35 +1238,7 @@ public string Replace(string oldValue, string? newValue)
}

// Find all occurrences of the oldValue character.
char c = oldValue[0];
int i = 0;

if (PackedSpanHelpers.PackedIndexOfIsSupported && PackedSpanHelpers.CanUsePackedIndexOf(c))
{
while (true)
{
int pos = PackedSpanHelpers.IndexOf(ref Unsafe.Add(ref _firstChar, i), c, Length - i);
if (pos < 0)
{
break;
}
replacementIndices.Append(i + pos);
i += pos + 1;
}
}
else
{
while (true)
{
int pos = SpanHelpers.NonPackedIndexOfChar(ref Unsafe.Add(ref _firstChar, i), c, Length - i);
if (pos < 0)
{
break;
}
replacementIndices.Append(i + pos);
i += pos + 1;
}
}
MakeReplacementSearchVectorized(this, ref replacementIndices, oldValue[0]);
}
else
{
Expand Down Expand Up @@ -1285,6 +1271,91 @@ public string Replace(string oldValue, string? newValue)
return dst;
}

// Appends to replacementIndices, in ascending order, the index of every char in sourceSpan that equals c.
//
// The widest hardware-accelerated vector size for which the input holds at least two full vectors
// is used to scan the bulk of the input; whatever is left over (or the whole input, when no
// vector path applies) is scanned one char at a time by the trailing scalar loop.
private static void MakeReplacementSearchVectorized(ReadOnlySpan<char> sourceSpan, ref ValueListBuilder<int> replacementIndices, char c)
{
    nuint offset = 0;
    nuint lengthToExamine = (uint)sourceSpan.Length;
    ref char source = ref MemoryMarshal.GetReference(sourceSpan);

    // Each guard requires at least two full vectors, which also guarantees that the unsigned
    // subtraction in the matching loop condition cannot wrap around.
    if (Vector512.IsHardwareAccelerated && lengthToExamine >= (uint)Vector512<ushort>.Count * 2)
    {
        Vector512<ushort> v1 = Vector512.Create((ushort)c);
        do
        {
            Vector512<ushort> vector = Vector512.LoadUnsafe(ref source, offset);

            // Compare once and reuse the result for both the "any match" test and the mask,
            // mirroring the Vector256/Vector128 paths below.
            Vector512<byte> cmp = Vector512.Equals(vector, v1).AsByte();

            if (cmp != Vector512<byte>.Zero)
            {
                // A matching 16-bit lane sets two adjacent bits in the per-byte mask.
                // Skip every other bit so each match is reported exactly once.
                ulong mask = cmp.ExtractMostSignificantBits() & 0x5555555555555555;
                do
                {
                    // Dividing the byte position by sizeof(char) yields the char position within the vector.
                    uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
                    replacementIndices.Append((int)(offset + bitPos));
                    mask = BitOperations.ResetLowestSetBit(mask);
                } while (mask != 0);
            }

            offset += (nuint)Vector512<ushort>.Count;
        } while (offset <= lengthToExamine - (nuint)Vector512<ushort>.Count);
    }
    else if (Vector256.IsHardwareAccelerated && lengthToExamine >= (uint)Vector256<ushort>.Count * 2)
    {
        Vector256<ushort> v1 = Vector256.Create((ushort)c);
        do
        {
            Vector256<ushort> vector = Vector256.LoadUnsafe(ref source, offset);
            Vector256<byte> cmp = Vector256.Equals(vector, v1).AsByte();

            if (cmp != Vector256<byte>.Zero)
            {
                // Skip every other bit
                uint mask = cmp.ExtractMostSignificantBits() & 0x55555555;
                do
                {
                    uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
                    replacementIndices.Append((int)(offset + bitPos));
                    mask = BitOperations.ResetLowestSetBit(mask);
                } while (mask != 0);
            }

            offset += (nuint)Vector256<ushort>.Count;
        } while (offset <= lengthToExamine - (nuint)Vector256<ushort>.Count);
    }
    else if (Vector128.IsHardwareAccelerated && lengthToExamine >= (uint)Vector128<ushort>.Count * 2)
    {
        Vector128<ushort> v1 = Vector128.Create((ushort)c);
        do
        {
            Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
            Vector128<byte> cmp = Vector128.Equals(vector, v1).AsByte();

            if (cmp != Vector128<byte>.Zero)
            {
                // Skip every other bit
                uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
                do
                {
                    uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
                    replacementIndices.Append((int)(offset + bitPos));
                    mask = BitOperations.ResetLowestSetBit(mask);
                } while (mask != 0);
            }

            offset += (nuint)Vector128<ushort>.Count;
        } while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
    }

    // Scalar tail: handles the chars after the last full vector, or the entire input
    // when it was too short (or no vector size is accelerated) for any path above.
    while (offset < lengthToExamine)
    {
        char curr = Unsafe.Add(ref source, offset);
        if (curr == c)
        {
            replacementIndices.Append((int)offset);
        }
        offset++;
    }
}

private string ReplaceHelper(int oldValueLength, string newValue, ReadOnlySpan<int> indices)
{
Debug.Assert(indices.Length > 0);
Expand Down Expand Up @@ -1899,46 +1970,98 @@ internal static void MakeSeparatorListAny(ReadOnlySpan<char> source, ReadOnlySpa

private static void MakeSeparatorListVectorized(ReadOnlySpan<char> sourceSpan, ref ValueListBuilder<int> sepListBuilder, char c, char c2, char c3)
{
// Redundant test so we won't prejit remainder of this method
// on platforms where it is not supported
if (!Vector128.IsHardwareAccelerated)
Debug.Assert(sourceSpan.Length >= Vector128<ushort>.Count);
nuint lengthToExamine = (uint)sourceSpan.Length;
nuint offset = 0;
ref char source = ref MemoryMarshal.GetReference(sourceSpan);

if (Vector512.IsHardwareAccelerated && lengthToExamine >= (uint)Vector512<ushort>.Count*2)
{
throw new PlatformNotSupportedException();
}
Vector512<ushort> v1 = Vector512.Create((ushort)c);
Vector512<ushort> v2 = Vector512.Create((ushort)c2);
Vector512<ushort> v3 = Vector512.Create((ushort)c3);

Debug.Assert(sourceSpan.Length >= Vector128<ushort>.Count);
do
{
Vector512<ushort> vector = Vector512.LoadUnsafe(ref source, offset);
Vector512<ushort> v1Eq = Vector512.Equals(vector, v1);
Vector512<ushort> v2Eq = Vector512.Equals(vector, v2);
Vector512<ushort> v3Eq = Vector512.Equals(vector, v3);
Vector512<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

nuint offset = 0;
nuint lengthToExamine = (uint)sourceSpan.Length;
if (cmp != Vector512<byte>.Zero)
{
// Skip every other bit
ulong mask = cmp.ExtractMostSignificantBits() & 0x5555555555555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

ref char source = ref MemoryMarshal.GetReference(sourceSpan);
offset += (nuint)Vector512<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector512<ushort>.Count);
}
else if (Vector256.IsHardwareAccelerated && lengthToExamine >= (uint)Vector256<ushort>.Count*2)
{
Vector256<ushort> v1 = Vector256.Create((ushort)c);
Vector256<ushort> v2 = Vector256.Create((ushort)c2);
Vector256<ushort> v3 = Vector256.Create((ushort)c3);

do
{
Vector256<ushort> vector = Vector256.LoadUnsafe(ref source, offset);
Vector256<ushort> v1Eq = Vector256.Equals(vector, v1);
Vector256<ushort> v2Eq = Vector256.Equals(vector, v2);
Vector256<ushort> v3Eq = Vector256.Equals(vector, v3);
Vector256<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);
if (cmp != Vector256<byte>.Zero)
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x55555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

do
offset += (nuint)Vector256<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector256<ushort>.Count);
}
else if (Vector128.IsHardwareAccelerated)
{
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<ushort> v1Eq = Vector128.Equals(vector, v1);
Vector128<ushort> v2Eq = Vector128.Equals(vector, v2);
Vector128<ushort> v3Eq = Vector128.Equals(vector, v3);
Vector128<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();
Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);

if (cmp != Vector128<byte>.Zero)
do
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<ushort> v1Eq = Vector128.Equals(vector, v1);
Vector128<ushort> v2Eq = Vector128.Equals(vector, v2);
Vector128<ushort> v3Eq = Vector128.Equals(vector, v3);
Vector128<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

if (cmp != Vector128<byte>.Zero)
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
}

while (offset < lengthToExamine)
{
Expand Down