Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize TensorPrimitives.ConvertToHalf #92715

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,301 @@ public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destinat
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < source.Length; i++)
ref float sourceRef = ref MemoryMarshal.GetReference(source);
ref ushort destinationRef = ref Unsafe.As<Half, ushort>(ref MemoryMarshal.GetReference(destination));
int i = 0, twoVectorsFromEnd;

#if NET8_0_OR_GREATER
if (Vector512.IsHardwareAccelerated)
{
destination[i] = (Half)source[i];
twoVectorsFromEnd = source.Length - (Vector512<float>.Count * 2);
if (i <= twoVectorsFromEnd)
{
// Loop handling two input vectors / one output vector at a time.
do
{
Vector512<uint> lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
Vector512<uint> upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512<float>.Count)));
Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);

i += Vector512<float>.Count * 2;
}
while (i <= twoVectorsFromEnd);

// Handle any remaining elements with final vectors.
if (i != source.Length)
{
i = source.Length - (Vector512<float>.Count * 2);

Vector512<uint> lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
Vector512<uint> upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512<float>.Count)));
Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
}

return;
}
}
#endif

if (Vector256.IsHardwareAccelerated)
{
twoVectorsFromEnd = source.Length - (Vector256<float>.Count * 2);
if (i <= twoVectorsFromEnd)
{
// Loop handling two input vectors / one output vector at a time.
do
{
Vector256<uint> lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
Vector256<uint> upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256<float>.Count)));
Vector256<ushort> halfs = Vector256.Narrow(lower, upper);
halfs.StoreUnsafe(ref destinationRef, (uint)i);

i += Vector256<float>.Count * 2;
}
while (i <= twoVectorsFromEnd);

// Handle any remaining elements with final vectors.
if (i != source.Length)
{
i = source.Length - (Vector256<float>.Count * 2);

Vector256<uint> lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
Vector256<uint> upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256<float>.Count)));
Vector256.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
}

return;
}
}

if (Vector128.IsHardwareAccelerated)
{
twoVectorsFromEnd = source.Length - (Vector128<float>.Count * 2);
if (i <= twoVectorsFromEnd)
{
// Loop handling two input vectors / one output vector at a time.
do
{
Vector128<uint> lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
Vector128<uint> upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128<float>.Count)));
Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);

i += Vector128<float>.Count * 2;
}
while (i <= twoVectorsFromEnd);

// Handle any remaining elements with final vectors.
if (i != source.Length)
{
i = source.Length - (Vector128<float>.Count * 2);

Vector128<uint> lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
Vector128<uint> upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128<float>.Count)));
Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
}

return;
}
}

while (i < source.Length)
{
Unsafe.Add(ref destinationRef, i) = BitConverter.HalfToUInt16Bits((Half)Unsafe.Add(ref sourceRef, i));
i++;
}

// This implements a vectorized version of the `explicit operator Half(float value)` conversion operator.
// See detailed description of the algorithm used here:
// https://github.com/dotnet/runtime/blob/ca8d6f0420096831766ec11c7d400e4f7ccc7a34/src/libraries/System.Private.CoreLib/src/System/Half.cs#L606-L714
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
// The cast operator converts a float to a Half represented as a UInt32, then narrows to a UInt16, and reinterpret casts to Half.
// This does the same, with an input VectorXx<float> and an output VectorXx<uint>.
// Loop handling two input vectors at a time; each input float is double the size of each output Half,
// so we need two vectors of floats to produce one vector of Halfs. Half isn't supported in VectorXx<T>,
// so we convert the VectorXx<float> to a VectorXx<uint>, and the caller then uses this twice, narrows the combination
// into a VectorXx<ushort>, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding
const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1
const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask
const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2
const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half
const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float
#pragma warning restore IDE0059

// Converts every float lane of 'value' into the bit pattern of the equivalent Half, with each result left
// widened in a uint lane. The caller runs this on two input vectors and narrows the pair into one
// Vector128<ushort>. Relies on the MinExp, Exponent126, SingleBiasedExponentMask, Exponent13,
// MaxHalfValueBelowInfinity, ExponentMask and SingleSignMask constants declared just above.
static Vector128<uint> SingleToHalfAsWidenedUInt32_Vector128(Vector128<float> value)
{
Vector128<uint> bitValue = value.AsUInt32();

// Extract the sign bit and shift it right by 16 so it moves from bit 31 down to bit 15
Vector128<uint> sign = Vector128.ShiftRightLogical(bitValue & Vector128.Create(SingleSignMask), 16);

// Detecting NaN (0u if value is NaN; otherwise, ~0u), since NaN never compares equal to itself
Vector128<uint> realMask = Vector128.Equals(value, value).AsUInt32();

// Clear sign bit; from here on only the magnitude is processed (the sign was saved above)
value = Vector128.Abs(value);

// Clamp the magnitude to at most MaxHalfValueBelowInfinity (rectifies values that are Infinity in Half).
value = Vector128.Min(Vector128.Create(MaxHalfValueBelowInfinity), value);

// Rectify lower exponent: raise the magnitude to at least the float whose bits are MinExp
Vector128<uint> exponentOffset0 = Vector128.Max(value, Vector128.Create(MinExp).AsSingle()).AsUInt32();

// Extract exponent (keep only the biased exponent bits)
exponentOffset0 &= Vector128.Create(SingleBiasedExponentMask);

// Add 13 to the exponent field (Exponent13 is 13 << 23)
exponentOffset0 += Vector128.Create(Exponent13);

// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
value += exponentOffset0.AsSingle();
bitValue = value.AsUInt32();

// ExponentMask for NaN lanes, zero for every other lane; only exponent bits will be modified if NaN
Vector128<uint> maskedHalfExponentForNaN = ~realMask & Vector128.Create(ExponentMask);

// Subtract 126 from the exponent field (Exponent126 is 126 << 23)
bitValue -= Vector128.Create(Exponent126);

// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
Vector128<uint> newExponent = Vector128.ShiftRightLogical(bitValue, 13);

// Zero bitValue entirely in lanes where the value was NaN (realMask is 0u there).
bitValue &= realMask;

// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
bitValue += newExponent;

// Clear exponents if value is NaN
bitValue &= ~maskedHalfExponentForNaN;

// Combine the sign bit with the possible NaN exponent
Vector128<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;

// Merge the sign bit and possible NaN exponent into the result
bitValue |= signAndMaskedExponent;

// The final result
return bitValue;
}

// Converts every float lane of 'value' into the bit pattern of the equivalent Half, with each result left
// widened in a uint lane. The caller runs this on two input vectors and narrows the pair into one
// Vector256<ushort>. Relies on the MinExp, Exponent126, SingleBiasedExponentMask, Exponent13,
// MaxHalfValueBelowInfinity, ExponentMask and SingleSignMask constants declared just above.
static Vector256<uint> SingleToHalfAsWidenedUInt32_Vector256(Vector256<float> value)
{
Vector256<uint> bitValue = value.AsUInt32();

// Extract the sign bit and shift it right by 16 so it moves from bit 31 down to bit 15
Vector256<uint> sign = Vector256.ShiftRightLogical(bitValue & Vector256.Create(SingleSignMask), 16);

// Detecting NaN (0u if value is NaN; otherwise, ~0u), since NaN never compares equal to itself
Vector256<uint> realMask = Vector256.Equals(value, value).AsUInt32();

// Clear sign bit; from here on only the magnitude is processed (the sign was saved above)
value = Vector256.Abs(value);

// Clamp the magnitude to at most MaxHalfValueBelowInfinity (rectifies values that are Infinity in Half).
value = Vector256.Min(Vector256.Create(MaxHalfValueBelowInfinity), value);

// Rectify lower exponent: raise the magnitude to at least the float whose bits are MinExp
Vector256<uint> exponentOffset0 = Vector256.Max(value, Vector256.Create(MinExp).AsSingle()).AsUInt32();

// Extract exponent (keep only the biased exponent bits)
exponentOffset0 &= Vector256.Create(SingleBiasedExponentMask);

// Add 13 to the exponent field (Exponent13 is 13 << 23)
exponentOffset0 += Vector256.Create(Exponent13);

// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
value += exponentOffset0.AsSingle();
bitValue = value.AsUInt32();

// ExponentMask for NaN lanes, zero for every other lane; only exponent bits will be modified if NaN
Vector256<uint> maskedHalfExponentForNaN = ~realMask & Vector256.Create(ExponentMask);

// Subtract 126 from the exponent field (Exponent126 is 126 << 23)
bitValue -= Vector256.Create(Exponent126);

// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
Vector256<uint> newExponent = Vector256.ShiftRightLogical(bitValue, 13);

// Zero bitValue entirely in lanes where the value was NaN (realMask is 0u there).
bitValue &= realMask;

// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
bitValue += newExponent;

// Clear exponents if value is NaN
bitValue &= ~maskedHalfExponentForNaN;

// Combine the sign bit with the possible NaN exponent
Vector256<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;

// Merge the sign bit and possible NaN exponent into the result
bitValue |= signAndMaskedExponent;

// The final result
return bitValue;
}

#if NET8_0_OR_GREATER
// Converts every float lane of 'value' into the bit pattern of the equivalent Half, with each result left
// widened in a uint lane. The caller runs this on two input vectors and narrows the pair into one
// Vector512<ushort>. Relies on the MinExp, Exponent126, SingleBiasedExponentMask, Exponent13,
// MaxHalfValueBelowInfinity, ExponentMask and SingleSignMask constants declared just above.
static Vector512<uint> SingleToHalfAsWidenedUInt32_Vector512(Vector512<float> value)
{
Vector512<uint> bitValue = value.AsUInt32();

// Extract the sign bit and shift it right by 16 so it moves from bit 31 down to bit 15
Vector512<uint> sign = Vector512.ShiftRightLogical(bitValue & Vector512.Create(SingleSignMask), 16);

// Detecting NaN (0u if value is NaN; otherwise, ~0u), since NaN never compares equal to itself
Vector512<uint> realMask = Vector512.Equals(value, value).AsUInt32();

// Clear sign bit; from here on only the magnitude is processed (the sign was saved above)
value = Vector512.Abs(value);

// Clamp the magnitude to at most MaxHalfValueBelowInfinity (rectifies values that are Infinity in Half).
value = Vector512.Min(Vector512.Create(MaxHalfValueBelowInfinity), value);

// Rectify lower exponent: raise the magnitude to at least the float whose bits are MinExp
Vector512<uint> exponentOffset0 = Vector512.Max(value, Vector512.Create(MinExp).AsSingle()).AsUInt32();

// Extract exponent (keep only the biased exponent bits)
exponentOffset0 &= Vector512.Create(SingleBiasedExponentMask);

// Add 13 to the exponent field (Exponent13 is 13 << 23)
exponentOffset0 += Vector512.Create(Exponent13);

// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
value += exponentOffset0.AsSingle();
bitValue = value.AsUInt32();

// ExponentMask for NaN lanes, zero for every other lane; only exponent bits will be modified if NaN
Vector512<uint> maskedHalfExponentForNaN = ~realMask & Vector512.Create(ExponentMask);

// Subtract 126 from the exponent field (Exponent126 is 126 << 23)
bitValue -= Vector512.Create(Exponent126);

// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
Vector512<uint> newExponent = Vector512.ShiftRightLogical(bitValue, 13);

// Zero bitValue entirely in lanes where the value was NaN (realMask is 0u there).
bitValue &= realMask;

// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
bitValue += newExponent;

// Clear exponents if value is NaN
bitValue &= ~maskedHalfExponentForNaN;

// Combine the sign bit with the possible NaN exponent
Vector512<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;

// Merge the sign bit and possible NaN exponent into the result
bitValue |= signAndMaskedExponent;

// The final result
return bitValue;
}
#endif
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public static void ConvertToHalf(int tensorLength)
using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
foreach (int destLength in new[] { source.Length, source.Length + 1 })
{
Half[] destination = new Half[destLength];
using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(destLength);
destination.Span.Fill(Half.Zero);

TensorPrimitives.ConvertToHalf(source, destination);

Expand All @@ -35,6 +36,28 @@ public static void ConvertToHalf(int tensorLength)
}
}

[Theory]
[MemberData(nameof(TensorLengths))]
public static void ConvertToHalf_SpecialValues(int tensorLength)
{
    // CreateAndFillTensor supplies the baseline input; destination has exactly the same length.
    using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
    using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(tensorLength);

    // NaN, infinities, and 0s, each written to an independently chosen random slot
    // (a later value may therefore overwrite an earlier one).
    float[] specialValues = { float.NaN, float.PositiveInfinity, float.NegativeInfinity, 0, float.NegativeZero };
    foreach (float special in specialValues)
    {
        source[s_random.Next(source.Length)] = special;
    }

    TensorPrimitives.ConvertToHalf(source, destination);

    // Every element must match what the scalar float -> Half cast produces.
    int index = 0;
    while (index < source.Length)
    {
        Assert.Equal((Half)source[index], destination[index]);
        index++;
    }
}

[Theory]
[MemberData(nameof(TensorLengths))]
public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
Expand All @@ -51,7 +74,7 @@ public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
[MemberData(nameof(TensorLengthsIncluding0))]
public static void ConvertToSingle(int tensorLength)
{
Half[] source = new Half[tensorLength];
using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
for (int i = 0; i < source.Length; i++)
{
source[i] = (Half)s_random.NextSingle();
Expand All @@ -78,6 +101,32 @@ public static void ConvertToSingle(int tensorLength)
}
}
}
[Theory]
[MemberData(nameof(TensorLengths))]
public static void ConvertToSingle_SpecialValues(int tensorLength)
{
    // Populate the Half input with random values converted from float.
    using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
    int fillIndex = 0;
    while (fillIndex < source.Length)
    {
        source[fillIndex] = (Half)s_random.NextSingle();
        fillIndex++;
    }

    using BoundedMemory<float> destination = CreateTensor(tensorLength);

    // NaN, infinities, and 0s, each written to an independently chosen random slot
    // (a later value may therefore overwrite an earlier one).
    Half[] specialValues = { Half.NaN, Half.PositiveInfinity, Half.NegativeInfinity, Half.Zero, Half.NegativeZero };
    foreach (Half special in specialValues)
    {
        source[s_random.Next(source.Length)] = special;
    }

    TensorPrimitives.ConvertToSingle(source, destination);

    // Every element must match what the scalar Half -> float cast produces.
    int index = 0;
    while (index < source.Length)
    {
        Assert.Equal((float)source[index], destination[index]);
        index++;
    }
}

[Theory]
[MemberData(nameof(TensorLengths))]
Expand Down
Loading