
Add AVX2 versions of CombinedShannonEntropy #1848


Merged
12 commits merged on Nov 27, 2021
20 changes: 20 additions & 0 deletions src/ImageSharp/Common/Helpers/Numerics.cs
@@ -820,6 +820,26 @@ public static int ReduceSum(Vector128<int> accumulator)
}
}

/// <summary>
/// Reduces elements of the vector into one sum.
/// </summary>
/// <param name="accumulator">The accumulator to reduce.</param>
/// <returns>The sum of all elements.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int ReduceSum(Vector256<int> accumulator)
{
// Add upper lane to lower lane.
Vector128<int> vsum = Sse2.Add(accumulator.GetLower(), accumulator.GetUpper());

// Add odd to even.
vsum = Sse2.Add(vsum, Sse2.Shuffle(vsum, 0b_11_11_01_01));

// Add high to low.
vsum = Sse2.Add(vsum, Sse2.Shuffle(vsum, 0b_11_10_11_10));

return Sse2.ConvertToInt32(vsum);
}

/// <summary>
/// Reduces even elements of the vector into one sum.
/// </summary>
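For reference, the two shuffles in the new Vector256<int> ReduceSum overload implement a plain horizontal add. A scalar sketch of the equivalent computation (illustration only, not part of this diff; the method name is made up):

    // Scalar equivalent of the Vector256<int> ReduceSum above (illustration only).
    static int ReduceSumScalar(ReadOnlySpan<int> lanes) // expects 8 lanes
    {
        // Fold the upper 128-bit lane onto the lower one: [0..3] + [4..7].
        Span<int> v = stackalloc int[4];
        for (int i = 0; i < 4; i++)
        {
            v[i] = lanes[i] + lanes[i + 4];
        }

        // The first shuffle/add pairs odd lanes into even ones;
        // the second adds the high pair to the low pair.
        return (v[0] + v[1]) + (v[2] + v[3]);
    }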
192 changes: 175 additions & 17 deletions src/ImageSharp/Formats/Webp/Lossless/LosslessUtils.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0.

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using SixLabors.ImageSharp.Memory;
@@ -761,28 +762,184 @@ public static void BundleColorMap(Span<byte> row, int width, int xBits, Span<uint> dst)
/// <returns>Shannon entropy.</returns>
public static float CombinedShannonEntropy(Span<int> x, Span<int> y)
{
#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported)
{
double retVal = 0.0d;
Vector256<int> tmp = Vector256<int>.Zero; // Scratch space of sizeof(int) * 8 bytes.
ref int xRef = ref MemoryMarshal.GetReference(x);
ref int yRef = ref MemoryMarshal.GetReference(y);
Vector256<int> sumXY256 = Vector256<int>.Zero;
Vector256<int> sumX256 = Vector256<int>.Zero;
ref int tmpRef = ref Unsafe.As<Vector256<int>, int>(ref tmp);
for (nint i = 0; i < 256; i += 8)
{
Vector256<int> xVec = Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref xRef, i));
Vector256<int> yVec = Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref yRef, i));

// Check if any X is non-zero: this actually provides a speedup as X is usually sparse.
int mask = Avx2.MoveMask(Avx2.CompareEqual(xVec, Vector256<int>.Zero).AsByte());
if (mask != -1)
{
Vector256<int> xy256 = Avx2.Add(xVec, yVec);
sumXY256 = Avx2.Add(sumXY256, xy256);
sumX256 = Avx2.Add(sumX256, xVec);

// Analyze the different X + Y.
Unsafe.As<int, Vector256<int>>(ref tmpRef) = xy256;
if (tmpRef != 0)
{
retVal -= FastSLog2((uint)tmpRef);
if (Unsafe.Add(ref xRef, i) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i));
}
}

if (Unsafe.Add(ref tmpRef, 1) != 0)
Collaborator Author:

Note: I have tried to put those repeating if statements into their own method calls, but profiling has shown that this actually makes it slower, even with aggressive inlining.

Member:

Couldn't it be a loop? It looks incremental 0-7 to me.

However, my money says there's something clever that can be done with masking here to determine if each element != 0 and apply the diff as a single operation. I'm sure I've seen similar before.

Contributor:

SIMD? You can check each element in a vector without if-checks at all for the tmpRef and xRef vectors.

You can also remove the if-checks with a log2 precalculation for each case and simply multiply those log2 vectors with the comparison mask values. This may or may not be faster than the if-checks; it depends on the FastSLog2 implementation.

Collaborator Author:

The original SSE2 version of this uses macros. I tried to keep it as similar as possible. Because of that, I thought a loop might also be slower, but I will give it a try and test it.

FastSLog2 is in most cases (>90%) just a lookup table, but most of the time is still spent here.
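For context, FastSLog2 here follows libwebp's VP8LFastSLog2: a precomputed table of v * log2(v) for small v, with a slow path for larger values (FastSLog2Slow appears further down in this diff). A simplified sketch of the lookup fast path; the names, table size of 256, and table construction are assumptions for illustration, not the exact ImageSharp code:

    // Simplified sketch of a v * log2(v) lookup table (assumption: covers v in [0, 256)).
    private const int LogLookupIdxMax = 256;
    private static readonly float[] SLog2Table = BuildSLog2Table();

    private static float[] BuildSLog2Table()
    {
        float[] table = new float[LogLookupIdxMax];
        for (int v = 1; v < LogLookupIdxMax; v++)
        {
            table[v] = (float)(v * Math.Log(v, 2));
        }

        return table; // table[0] stays 0, so FastSLog2(0) == 0.
    }

    private static float FastSLog2Sketch(uint v)
        => v < LogLookupIdxMax ? SLog2Table[v] : FastSLog2Slow(v);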

Contributor:

I guess a loop can further degrade the branch predictor history. Masking is still possible, but that would require a serious rewrite of the AVX branch.

Collaborator Author:

I have done profiling tests with loops several times now, and it makes it noticeably slower.

I am not really sure it is worth trying to optimize this further if it would take a huge amount of work to do so, as can be seen here:
[profiler screenshot: shannon_avx2]

The main time consumer is FastSLog2.

Contributor:

I honestly have no idea how webp works, but as far as I understand the checked value is actually pretty random; adding a for-loop on top of it may screw up the branch predictor with yet another stable (always true for 8 iterations) if-check. That's most likely what's causing the performance drop.

All of the extra if-checks can be removed via SIMD masks, but it's a question whether that would be faster than the current implementation.

Collaborator Author:

I had a previous SSE2 version which used masks (see commit ed8bd61), but it was not better than the current approach. Maybe an AVX version of that could be better, but I am also not sure it can really beat the current one.

Contributor (@br3aker), Nov 24, 2021:

It's actually easier to describe my proposal in code than in human language. I'll probably try to implement it after this gets merged.

Collaborator Author:

@br3aker thanks, but my advice would be to only do it if you think it would be easy for you and not too much work. As I said before, I am really unsure how much we can gain from this.
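For illustration, a rough sketch of the mask-based variant discussed in this thread (hypothetical and untested, not the PR code; LoadSLog2 and ReduceSumFloat are assumed helpers that perform eight scalar FastSLog2 lookups and a float horizontal sum). Because FastSLog2(0) is 0, zero lanes can simply be masked out, which would replace the per-element if-checks in both the sparse and the all-zero branch:

    // Hypothetical per-block body (xVec, yVec and retVal as in the loop above).
    Vector256<int> xy256 = Avx2.Add(xVec, yVec);
    Vector256<float> slogXY = LoadSLog2(xy256); // assumed helper: 8 scalar FastSLog2 lookups
    Vector256<float> slogX = LoadSLog2(xVec);   // assumed helper
    Vector256<float> xyZero = Avx2.CompareEqual(xy256, Vector256<int>.Zero).AsSingle();
    Vector256<float> xZero = Avx2.CompareEqual(xVec, Vector256<int>.Zero).AsSingle();

    // AndNot(mask, value) keeps lanes whose source lane was non-zero and zeroes the rest.
    Vector256<float> toSubtract = Avx.Add(Avx.AndNot(xyZero, slogXY), Avx.AndNot(xZero, slogX));
    retVal -= ReduceSumFloat(toSubtract);       // assumed helper: horizontal float sum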

{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 1));
if (Unsafe.Add(ref xRef, i + 1) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 1));
}
}

if (Unsafe.Add(ref tmpRef, 2) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 2));
if (Unsafe.Add(ref xRef, i + 2) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 2));
}
}

if (Unsafe.Add(ref tmpRef, 3) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 3));
if (Unsafe.Add(ref xRef, i + 3) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 3));
}
}

if (Unsafe.Add(ref tmpRef, 4) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 4));
if (Unsafe.Add(ref xRef, i + 4) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 4));
}
}

if (Unsafe.Add(ref tmpRef, 5) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 5));
if (Unsafe.Add(ref xRef, i + 5) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 5));
}
}

if (Unsafe.Add(ref tmpRef, 6) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 6));
if (Unsafe.Add(ref xRef, i + 6) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 6));
}
}

if (Unsafe.Add(ref tmpRef, 7) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 7));
if (Unsafe.Add(ref xRef, i + 7) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 7));
}
}
}
else
{
// X is fully 0, so only deal with Y.
sumXY256 = Avx2.Add(sumXY256, yVec);

if (Unsafe.Add(ref yRef, i) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i));
}

if (Unsafe.Add(ref yRef, i + 1) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 1));
}

if (Unsafe.Add(ref yRef, i + 2) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 2));
}

if (Unsafe.Add(ref yRef, i + 3) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 3));
}

if (Unsafe.Add(ref yRef, i + 4) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 4));
}

if (Unsafe.Add(ref yRef, i + 5) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 5));
}

if (Unsafe.Add(ref yRef, i + 6) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 6));
}

if (Unsafe.Add(ref yRef, i + 7) != 0)
{
retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 7));
}
}
}

// Sum up sumX256 to get sumX and sum up sumXY256 to get sumXY.
int sumX = Numerics.ReduceSum(sumX256);
int sumXY = Numerics.ReduceSum(sumXY256);

retVal += FastSLog2((uint)sumX) + FastSLog2((uint)sumXY);

return (float)retVal;
}
else
#endif
{
double retVal = 0.0d;
uint sumX = 0, sumXY = 0;
for (int i = 0; i < 256; i++)
{
uint xi = (uint)x[i];
if (xi != 0)
{
uint xy = xi + (uint)y[i];
sumX += xi;
retVal -= FastSLog2(xi);
sumXY += xy;
retVal -= FastSLog2(xy);
}
else if (y[i] != 0)
{
sumXY += (uint)y[i];
retVal -= FastSLog2((uint)y[i]);
}
}

retVal += FastSLog2(sumX) + FastSLog2(sumXY);
return (float)retVal;
}
}

[MethodImpl(InliningOptions.ShortMethod)]
@@ -838,6 +995,7 @@ public static void ColorCodeToMultipliers(uint colorCode, ref Vp8LMultipliers m)
private static float FastSLog2Slow(uint v)
{
DebugGuard.MustBeGreaterThanOrEqualTo<uint>(v, LogLookupIdxMax, nameof(v));

if (v < ApproxLogWithCorrectionMax)
{
int logCnt = 0;
@@ -867,7 +1025,7 @@ private static float FastSLog2Slow(uint v)

private static float FastLog2Slow(uint v)
{
Guard.MustBeGreaterThanOrEqualTo(v, LogLookupIdxMax, nameof(v));
DebugGuard.MustBeGreaterThanOrEqualTo<uint>(v, LogLookupIdxMax, nameof(v));

if (v < ApproxLogWithCorrectionMax)
{
28 changes: 24 additions & 4 deletions tests/ImageSharp.Tests/Formats/WebP/LosslessUtilsTests.cs
@@ -10,6 +10,17 @@ namespace SixLabors.ImageSharp.Tests.Formats.Webp
[Trait("Format", "Webp")]
public class LosslessUtilsTests
{
private static void RunCombinedShannonEntropyTest()
{
int[] x = { 3, 5, 2, 5, 3, 1, 2, 2, 3, 3, 1, 2, 1, 2, 1, 1, 0, 0, 0, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 1, 0, 0, 2, 1, 1, 0, 3, 1, 2, 3, 2, 3 };
int[] y = { 11, 12, 8, 3, 4, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 2, 1, 1, 2, 4, 6, 4 };
float expected = 884.7585f;

float actual = LosslessUtils.CombinedShannonEntropy(x, y);

Assert.Equal(expected, actual, 5);
}

private static void RunSubtractGreenTest()
{
uint[] pixelData =
@@ -193,6 +204,9 @@ private static void RunPredictor13Test()
}
}

[Fact]
public void CombinedShannonEntropy_Works() => RunCombinedShannonEntropyTest();

[Fact]
public void Predictor11_Works() => RunPredictor11Test();

@@ -216,6 +230,12 @@ private static void RunPredictor13Test()

#if SUPPORTS_RUNTIME_INTRINSICS

[Fact]
public void CombinedShannonEntropy_WithHardwareIntrinsics_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunCombinedShannonEntropyTest, HwIntrinsics.AllowAll);

[Fact]
public void CombinedShannonEntropy_WithoutAVX2_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunCombinedShannonEntropyTest, HwIntrinsics.DisableAVX2);

[Fact]
public void Predictor11_WithHardwareIntrinsics_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunPredictor11Test, HwIntrinsics.AllowAll);

@@ -238,19 +258,19 @@ private static void RunPredictor13Test()
public void SubtractGreen_WithHardwareIntrinsics_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunSubtractGreenTest, HwIntrinsics.AllowAll);

[Fact]
public void SubtractGreen_WithoutAvx_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunSubtractGreenTest, HwIntrinsics.DisableAVX);
public void SubtractGreen_WithoutAVX2_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunSubtractGreenTest, HwIntrinsics.DisableAVX2);

[Fact]
public void SubtractGreen_WithoutAvxOrSSSE3_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunSubtractGreenTest, HwIntrinsics.DisableAVX | HwIntrinsics.DisableSSSE3);
public void SubtractGreen_WithoutAvxOrSSSE3_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunSubtractGreenTest, HwIntrinsics.DisableAVX2 | HwIntrinsics.DisableSSSE3);

[Fact]
public void AddGreenToBlueAndRed_WithHardwareIntrinsics_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunAddGreenToBlueAndRedTest, HwIntrinsics.AllowAll);

[Fact]
public void AddGreenToBlueAndRed_WithoutAvx_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunAddGreenToBlueAndRedTest, HwIntrinsics.DisableAVX);
public void AddGreenToBlueAndRed_WithoutAVX2_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunAddGreenToBlueAndRedTest, HwIntrinsics.DisableAVX2);

[Fact]
public void AddGreenToBlueAndRed_WithoutAvxOrSSSE3_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunAddGreenToBlueAndRedTest, HwIntrinsics.DisableAVX | HwIntrinsics.DisableSSE2 | HwIntrinsics.DisableSSSE3);
public void AddGreenToBlueAndRed_WithoutAVX2OrSSSE3_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunAddGreenToBlueAndRedTest, HwIntrinsics.DisableAVX2 | HwIntrinsics.DisableSSE2 | HwIntrinsics.DisableSSSE3);

[Fact]
public void TransformColor_WithHardwareIntrinsics_Works() => FeatureTestRunner.RunWithHwIntrinsicsFeature(RunTransformColorTest, HwIntrinsics.AllowAll);