Skip to content

Commit

Permalink
Adding vectorized implementations of Log to Vector64/128/256/512 (#96913
Browse files Browse the repository at this point in the history
)

* Adding vectorized implementations of Log to Vector64/128/256/512

* Accelerate TensorPrimitives.Log for double

* Ensure the ref assembly is updated to include the new Log method

* Fix the variance for one of the Log2 tests to account for the scalar fallback
  • Loading branch information
tannergooding authored Jan 17, 2024
1 parent 05fe3e0 commit 06005da
Show file tree
Hide file tree
Showing 12 changed files with 990 additions and 21 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,38 @@ internal static Vector128<ushort> LoadUnsafe(ref char source) =>
internal static Vector128<ushort> LoadUnsafe(ref char source, nuint elementOffset) =>
LoadUnsafe(ref Unsafe.As<char, ushort>(ref source), elementOffset);

/// <inheritdoc cref="Vector64.Log(Vector64{double})" />
public static Vector128<double> Log(Vector128<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogDouble<Vector128<double>, Vector128<long>, Vector128<ulong>>(vector);
}
else
{
return Create(
Vector64.Log(vector._lower),
Vector64.Log(vector._upper)
);
}
}

/// <inheritdoc cref="Vector64.Log(Vector64{float})" />
public static Vector128<float> Log(Vector128<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogSingle<Vector128<float>, Vector128<int>, Vector128<uint>>(vector);
}
else
{
return Create(
Vector64.Log(vector._lower),
Vector64.Log(vector._upper)
);
}
}

/// <inheritdoc cref="Vector64.Log2(Vector64{double})" />
public static Vector128<double> Log2(Vector128<double> vector)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,38 @@ internal static Vector256<ushort> LoadUnsafe(ref char source) =>
internal static Vector256<ushort> LoadUnsafe(ref char source, nuint elementOffset) =>
LoadUnsafe(ref Unsafe.As<char, ushort>(ref source), elementOffset);

/// <inheritdoc cref="Vector128.Log(Vector128{double})" />
public static Vector256<double> Log(Vector256<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogDouble<Vector256<double>, Vector256<long>, Vector256<ulong>>(vector);
}
else
{
return Create(
Vector128.Log(vector._lower),
Vector128.Log(vector._upper)
);
}
}

/// <inheritdoc cref="Vector128.Log(Vector128{float})" />
public static Vector256<float> Log(Vector256<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogSingle<Vector256<float>, Vector256<int>, Vector256<uint>>(vector);
}
else
{
return Create(
Vector128.Log(vector._lower),
Vector128.Log(vector._upper)
);
}
}

/// <inheritdoc cref="Vector128.Log2(Vector128{double})" />
public static Vector256<double> Log2(Vector256<double> vector)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1806,6 +1806,38 @@ internal static Vector512<ushort> LoadUnsafe(ref char source) =>
internal static Vector512<ushort> LoadUnsafe(ref char source, nuint elementOffset) =>
LoadUnsafe(ref Unsafe.As<char, ushort>(ref source), elementOffset);

/// <inheritdoc cref="Vector256.Log(Vector256{double})" />
public static Vector512<double> Log(Vector512<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogDouble<Vector512<double>, Vector512<long>, Vector512<ulong>>(vector);
}
else
{
return Create(
Vector256.Log(vector._lower),
Vector256.Log(vector._upper)
);
}
}

/// <inheritdoc cref="Vector256.Log(Vector256{float})" />
public static Vector512<float> Log(Vector512<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogSingle<Vector512<float>, Vector512<int>, Vector512<uint>>(vector);
}
else
{
return Create(
Vector256.Log(vector._lower),
Vector256.Log(vector._upper)
);
}
}

/// <inheritdoc cref="Vector256.Log2(Vector256{double})" />
public static Vector512<double> Log2(Vector512<double> vector)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,50 @@ public static Vector64<T> LoadUnsafe<T>(ref readonly T source, nuint elementOffs
return Unsafe.ReadUnaligned<Vector64<T>>(in address);
}

internal static Vector64<T> Log<T>(Vector64<T> vector)
where T : ILogarithmicFunctions<T>
{
Unsafe.SkipInit(out Vector64<T> result);

for (int index = 0; index < Vector64<T>.Count; index++)
{
T value = T.Log(vector.GetElement(index));
result.SetElementUnsafe(index, value);
}

return result;
}

/// <summary>Computes the log of each element in a vector.</summary>
/// <param name="vector">The vector that will have its log computed.</param>
/// <returns>A vector whose elements are the log of the elements in <paramref name="vector" />.</returns>
public static Vector64<double> Log(Vector64<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogDouble<Vector64<double>, Vector64<long>, Vector64<ulong>>(vector);
}
else
{
return Log<double>(vector);
}
}

/// <summary>Computes the log of each element in a vector.</summary>
/// <param name="vector">The vector that will have its log computed.</param>
/// <returns>A vector whose elements are the log of the elements in <paramref name="vector" />.</returns>
public static Vector64<float> Log(Vector64<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.LogSingle<Vector64<float>, Vector64<int>, Vector64<uint>>(vector);
}
else
{
return Log<float>(vector);
}
}

internal static Vector64<T> Log2<T>(Vector64<T> vector)
where T : ILogarithmicFunctions<T>
{
Expand Down
Loading

0 comments on commit 06005da

Please sign in to comment.