Skip to content

Vectorize Png encoder filters #1630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/ImageSharp/Common/Helpers/Numerics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -748,5 +748,82 @@ public static Vector256<float> Lerp(
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static float Lerp(float value1, float value2, float amount)
=> ((value2 - value1) * amount) + value1;

#if SUPPORTS_RUNTIME_INTRINSICS

/// <summary>
/// Accumulates 8-bit integers into <paramref name="accumulator"/> by
/// widening them to 32-bit integers and performing four additions.
/// </summary>
/// <remarks>
/// <code>byte(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)</code>
/// is widened and added onto <paramref name="accumulator"/> as such:
/// <code>
/// accumulator += i32(1, 2, 3, 4);
/// accumulator += i32(5, 6, 7, 8);
/// accumulator += i32(9, 10, 11, 12);
/// accumulator += i32(13, 14, 15, 16);
/// </code>
/// </remarks>
/// <param name="accumulator">The accumulator destination.</param>
/// <param name="values">The values to accumulate.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Accumulate(ref Vector<uint> accumulator, Vector<byte> values)
{
Vector.Widen(values, out Vector<ushort> shortLow, out Vector<ushort> shortHigh);

Vector.Widen(shortLow, out Vector<uint> intLow, out Vector<uint> intHigh);
accumulator += intLow;
accumulator += intHigh;

Vector.Widen(shortHigh, out intLow, out intHigh);
accumulator += intLow;
accumulator += intHigh;
}

/// <summary>
/// Reduces elements of the vector into one sum.
/// </summary>
/// <param name="accumulator">The accumulator to reduce.</param>
/// <returns>The sum of all elements.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int ReduceSum(Vector128<int> accumulator)
{
if (Ssse3.IsSupported)
{
Vector128<int> hadd = Ssse3.HorizontalAdd(accumulator, accumulator);
Vector128<int> swapped = Sse2.Shuffle(hadd, 0x1);
Vector128<int> tmp = Sse2.Add(hadd, swapped);

// Vector128<int>.ToScalar() isn't optimized pre-net5.0 https://github.com/dotnet/runtime/pull/37882
return Sse2.ConvertToInt32(tmp);
}
else
{
int sum = 0;
for (int i = 0; i < Vector128<int>.Count; i++)
{
sum += accumulator.GetElement(i);
}

return sum;
}
}

/// <summary>
/// Reduces even elements of the vector into one sum.
/// </summary>
/// <param name="accumulator">The accumulator to reduce.</param>
/// <returns>The sum of even elements.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int EvenReduceSum(Vector256<int> accumulator)
{
Vector128<int> vsum = Sse2.Add(accumulator.GetLower(), accumulator.GetUpper()); // add upper lane to lower lane
vsum = Sse2.Add(vsum, Sse2.Shuffle(vsum, 0b_11_10_11_10)); // add high to low

// Vector128<int>.ToScalar() isn't optimized pre-net5.0 https://github.com/dotnet/runtime/pull/37882
return Sse2.ConvertToInt32(vsum);
}
#endif
}
}
78 changes: 78 additions & 0 deletions src/ImageSharp/Formats/Png/Filters/AverageFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

#if SUPPORTS_RUNTIME_INTRINSICS
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
#endif

namespace SixLabors.ImageSharp.Formats.Png.Filters
{
/// <summary>
Expand Down Expand Up @@ -79,6 +84,79 @@ public static void Encode(Span<byte> scanline, Span<byte> previousScanline, Span
sum += Numerics.Abs(unchecked((sbyte)res));
}

#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported)
{
Vector256<byte> zero = Vector256<byte>.Zero;
Vector256<int> sumAccumulator = Vector256<int>.Zero;
Vector256<byte> allBitsSet = Avx2.CompareEqual(sumAccumulator, sumAccumulator).AsByte();

for (int xLeft = x - bytesPerPixel; x + Vector256<byte>.Count <= scanline.Length; xLeft += Vector256<byte>.Count)
{
Vector256<byte> scan = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref scanBaseRef, x));
Vector256<byte> left = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref scanBaseRef, xLeft));
Vector256<byte> above = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref prevBaseRef, x));

Vector256<byte> avg = Avx2.Xor(Avx2.Average(Avx2.Xor(left, allBitsSet), Avx2.Xor(above, allBitsSet)), allBitsSet);
Vector256<byte> res = Avx2.Subtract(scan, avg);

Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref resultBaseRef, x + 1)) = res; // +1 to skip filter type
x += Vector256<byte>.Count;

sumAccumulator = Avx2.Add(sumAccumulator, Avx2.SumAbsoluteDifferences(Avx2.Abs(res.AsSByte()), zero).AsInt32());
}

sum += Numerics.EvenReduceSum(sumAccumulator);
}
else if (Sse2.IsSupported)
{
Vector128<sbyte> zero8 = Vector128<sbyte>.Zero;
Vector128<short> zero16 = Vector128<short>.Zero;
Vector128<int> sumAccumulator = Vector128<int>.Zero;
Vector128<byte> allBitsSet = Sse2.CompareEqual(sumAccumulator, sumAccumulator).AsByte();

for (int xLeft = x - bytesPerPixel; x + Vector128<byte>.Count <= scanline.Length; xLeft += Vector128<byte>.Count)
{
Vector128<byte> scan = Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref scanBaseRef, x));
Vector128<byte> left = Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref scanBaseRef, xLeft));
Vector128<byte> above = Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref prevBaseRef, x));

Vector128<byte> avg = Sse2.Xor(Sse2.Average(Sse2.Xor(left, allBitsSet), Sse2.Xor(above, allBitsSet)), allBitsSet);
Vector128<byte> res = Sse2.Subtract(scan, avg);

Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref resultBaseRef, x + 1)) = res; // +1 to skip filter type
x += Vector128<byte>.Count;

Vector128<sbyte> absRes;
if (Ssse3.IsSupported)
{
absRes = Ssse3.Abs(res.AsSByte()).AsSByte();
}
else
{
Vector128<sbyte> mask = Sse2.CompareGreaterThan(res.AsSByte(), zero8);
mask = Sse2.Xor(mask, allBitsSet.AsSByte());
absRes = Sse2.Xor(Sse2.Add(res.AsSByte(), mask), mask);
}

Vector128<short> loRes16 = Sse2.UnpackLow(absRes, zero8).AsInt16();
Vector128<short> hiRes16 = Sse2.UnpackHigh(absRes, zero8).AsInt16();

Vector128<int> loRes32 = Sse2.UnpackLow(loRes16, zero16).AsInt32();
Vector128<int> hiRes32 = Sse2.UnpackHigh(loRes16, zero16).AsInt32();
sumAccumulator = Sse2.Add(sumAccumulator, loRes32);
sumAccumulator = Sse2.Add(sumAccumulator, hiRes32);

loRes32 = Sse2.UnpackLow(hiRes16, zero16).AsInt32();
hiRes32 = Sse2.UnpackHigh(hiRes16, zero16).AsInt32();
sumAccumulator = Sse2.Add(sumAccumulator, loRes32);
sumAccumulator = Sse2.Add(sumAccumulator, hiRes32);
}

sum += Numerics.ReduceSum(sumAccumulator);
}
#endif

for (int xLeft = x - bytesPerPixel; x < scanline.Length; ++xLeft /* Note: ++x happens in the body to avoid one add operation */)
{
byte scan = Unsafe.Add(ref scanBaseRef, x);
Expand Down
118 changes: 118 additions & 0 deletions src/ImageSharp/Formats/Png/Filters/PaethFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
// Licensed under the Apache License, Version 2.0.

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

#if SUPPORTS_RUNTIME_INTRINSICS
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
#endif

namespace SixLabors.ImageSharp.Formats.Png.Filters
{
/// <summary>
Expand Down Expand Up @@ -82,6 +88,53 @@ public static void Encode(Span<byte> scanline, Span<byte> previousScanline, Span
sum += Numerics.Abs(unchecked((sbyte)res));
}

#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported)
{
Vector256<byte> zero = Vector256<byte>.Zero;
Vector256<int> sumAccumulator = Vector256<int>.Zero;

for (int xLeft = x - bytesPerPixel; x + Vector256<byte>.Count <= scanline.Length; xLeft += Vector256<byte>.Count)
{
Vector256<byte> scan = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref scanBaseRef, x));
Vector256<byte> left = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref scanBaseRef, xLeft));
Vector256<byte> above = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref prevBaseRef, x));
Vector256<byte> upperLeft = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref prevBaseRef, xLeft));

Vector256<byte> res = Avx2.Subtract(scan, PaethPredictor(left, above, upperLeft));
Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref resultBaseRef, x + 1)) = res; // +1 to skip filter type
x += Vector256<byte>.Count;

sumAccumulator = Avx2.Add(sumAccumulator, Avx2.SumAbsoluteDifferences(Avx2.Abs(res.AsSByte()), zero).AsInt32());
}

sum += Numerics.EvenReduceSum(sumAccumulator);
}
else if (Vector.IsHardwareAccelerated)
{
Vector<uint> sumAccumulator = Vector<uint>.Zero;

for (int xLeft = x - bytesPerPixel; x + Vector<byte>.Count <= scanline.Length; xLeft += Vector<byte>.Count)
{
Vector<byte> scan = Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref scanBaseRef, x));
Vector<byte> left = Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref scanBaseRef, xLeft));
Vector<byte> above = Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref prevBaseRef, x));
Vector<byte> upperLeft = Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref prevBaseRef, xLeft));

Vector<byte> res = scan - PaethPredictor(left, above, upperLeft);
Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref resultBaseRef, x + 1)) = res; // +1 to skip filter type
x += Vector<byte>.Count;

Numerics.Accumulate(ref sumAccumulator, Vector.AsVectorByte(Vector.Abs(Vector.AsVectorSByte(res))));
}

for (int i = 0; i < Vector<uint>.Count; i++)
{
sum += (int)sumAccumulator[i];
}
}
#endif

for (int xLeft = x - bytesPerPixel; x < scanline.Length; ++xLeft /* Note: ++x happens in the body to avoid one add operation */)
{
byte scan = Unsafe.Add(ref scanBaseRef, x);
Expand Down Expand Up @@ -127,5 +180,70 @@ private static byte PaethPredictor(byte left, byte above, byte upperLeft)

return upperLeft;
}

#if SUPPORTS_RUNTIME_INTRINSICS
private static Vector256<byte> PaethPredictor(Vector256<byte> left, Vector256<byte> above, Vector256<byte> upleft)
{
Vector256<byte> zero = Vector256<byte>.Zero;

// Here, we refactor pa = abs(p - left) = abs(left + above - upleft - left)
// to pa = abs(above - upleft). Same deal for pb.
// Using saturated subtraction, if the result is negative, the output is zero.
// If we subtract in both directions and `or` the results, only one can be
// non-zero, so we end up with the absolute value.
Vector256<byte> sac = Avx2.SubtractSaturate(above, upleft);
Vector256<byte> sbc = Avx2.SubtractSaturate(left, upleft);
Vector256<byte> pa = Avx2.Or(Avx2.SubtractSaturate(upleft, above), sac);
Vector256<byte> pb = Avx2.Or(Avx2.SubtractSaturate(upleft, left), sbc);

// pc = abs(left + above - upleft - upleft), or abs(left - upleft + above - upleft).
// We've already calculated left - upleft and above - upleft in `sac` and `sbc`.
// If they are both negative or both positive, the absolute value of their
// sum can't possibly be less than `pa` or `pb`, so we'll never use the value.
// We make a mask that sets the value to 255 if they either both got
// saturated to zero or both didn't. Then we calculate the absolute value
// of their difference using saturated subtract and `or`, same as before,
// keeping the value only where the mask isn't set.
Vector256<byte> pm = Avx2.CompareEqual(Avx2.CompareEqual(sac, zero), Avx2.CompareEqual(sbc, zero));
Vector256<byte> pc = Avx2.Or(pm, Avx2.Or(Avx2.SubtractSaturate(pb, pa), Avx2.SubtractSaturate(pa, pb)));

// Finally, blend the values together. We start with `upleft` and overwrite on
// tied values so that the `left`, `above`, `upleft` precedence is preserved.
Vector256<byte> minbc = Avx2.Min(pc, pb);
Vector256<byte> resbc = Avx2.BlendVariable(upleft, above, Avx2.CompareEqual(minbc, pb));
return Avx2.BlendVariable(resbc, left, Avx2.CompareEqual(Avx2.Min(minbc, pa), pa));
}

private static Vector<byte> PaethPredictor(Vector<byte> left, Vector<byte> above, Vector<byte> upperLeft)
{
Vector.Widen(left, out Vector<ushort> a1, out Vector<ushort> a2);
Vector.Widen(above, out Vector<ushort> b1, out Vector<ushort> b2);
Vector.Widen(upperLeft, out Vector<ushort> c1, out Vector<ushort> c2);

Vector<short> p1 = PaethPredictor(Vector.AsVectorInt16(a1), Vector.AsVectorInt16(b1), Vector.AsVectorInt16(c1));
Vector<short> p2 = PaethPredictor(Vector.AsVectorInt16(a2), Vector.AsVectorInt16(b2), Vector.AsVectorInt16(c2));
return Vector.AsVectorByte(Vector.Narrow(p1, p2));
}

private static Vector<short> PaethPredictor(Vector<short> left, Vector<short> above, Vector<short> upperLeft)
{
Vector<short> p = left + above - upperLeft;
var pa = Vector.Abs(p - left);
var pb = Vector.Abs(p - above);
var pc = Vector.Abs(p - upperLeft);

var pa_pb = Vector.LessThanOrEqual(pa, pb);
var pa_pc = Vector.LessThanOrEqual(pa, pc);
var pb_pc = Vector.LessThanOrEqual(pb, pc);

return Vector.ConditionalSelect(
condition: Vector.BitwiseAnd(pa_pb, pa_pc),
left: left,
right: Vector.ConditionalSelect(
condition: pb_pc,
left: above,
right: upperLeft));
}
#endif
}
}
49 changes: 49 additions & 0 deletions src/ImageSharp/Formats/Png/Filters/SubFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
// Licensed under the Apache License, Version 2.0.

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

#if SUPPORTS_RUNTIME_INTRINSICS
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
#endif

namespace SixLabors.ImageSharp.Formats.Png.Filters
{
/// <summary>
Expand Down Expand Up @@ -64,6 +70,49 @@ public static void Encode(Span<byte> scanline, Span<byte> result, int bytesPerPi
sum += Numerics.Abs(unchecked((sbyte)res));
}

#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported)
{
Vector256<byte> zero = Vector256<byte>.Zero;
Vector256<int> sumAccumulator = Vector256<int>.Zero;

for (int xLeft = x - bytesPerPixel; x + Vector256<byte>.Count <= scanline.Length; xLeft += Vector256<byte>.Count)
{
Vector256<byte> scan = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref scanBaseRef, x));
Vector256<byte> prev = Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref scanBaseRef, xLeft));

Vector256<byte> res = Avx2.Subtract(scan, prev);
Unsafe.As<byte, Vector256<byte>>(ref Unsafe.Add(ref resultBaseRef, x + 1)) = res; // +1 to skip filter type
x += Vector256<byte>.Count;

sumAccumulator = Avx2.Add(sumAccumulator, Avx2.SumAbsoluteDifferences(Avx2.Abs(res.AsSByte()), zero).AsInt32());
}

sum += Numerics.EvenReduceSum(sumAccumulator);
}
else if (Vector.IsHardwareAccelerated)
{
Vector<uint> sumAccumulator = Vector<uint>.Zero;

for (int xLeft = x - bytesPerPixel; x + Vector<byte>.Count <= scanline.Length; xLeft += Vector<byte>.Count)
{
Vector<byte> scan = Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref scanBaseRef, x));
Vector<byte> prev = Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref scanBaseRef, xLeft));

Vector<byte> res = scan - prev;
Unsafe.As<byte, Vector<byte>>(ref Unsafe.Add(ref resultBaseRef, x + 1)) = res; // +1 to skip filter type
x += Vector<byte>.Count;

Numerics.Accumulate(ref sumAccumulator, Vector.AsVectorByte(Vector.Abs(Vector.AsVectorSByte(res))));
}

for (int i = 0; i < Vector<uint>.Count; i++)
{
sum += (int)sumAccumulator[i];
}
}
#endif

for (int xLeft = x - bytesPerPixel; x < scanline.Length; ++xLeft /* Note: ++x happens in the body to avoid one add operation */)
{
byte scan = Unsafe.Add(ref scanBaseRef, x);
Expand Down
Loading