Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SsaForecast bug #5023

Merged
merged 3 commits into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions src/Microsoft.ML.TimeSeries/PolynomialUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ public FactorMultiplicity(int multiplicity = 1)
private sealed class PolynomialFactor
{
public List<decimal> Coefficients;
public static decimal[] Destination;

private decimal _key;
public decimal Key { get { return _key; } }
Expand Down Expand Up @@ -175,20 +174,21 @@ internal PolynomialFactor(decimal key)
_key = key;
}

public void Multiply(PolynomialFactor factor)
public void Multiply(PolynomialFactor factor, decimal[] destination)
{
var len = Coefficients.Count;
Coefficients.AddRange(factor.Coefficients);

PolynomialMultiplication(0, len, len, factor.Coefficients.Count, 0, 1, 1);
PolynomialMultiplication(destination, 0, len, len, factor.Coefficients.Count, 0, 1, 1);

for (var i = 0; i < Coefficients.Count; ++i)
Coefficients[i] = Destination[i];
Coefficients[i] = destination[i];

SetKey();
}

private void PolynomialMultiplication(int uIndex, int uLen, int vIndex, int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
private void PolynomialMultiplication(decimal[] destination, int uIndex, int uLen, int vIndex,
int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
{
Contracts.Assert(uIndex >= 0);
Contracts.Assert(uLen >= 1);
Expand All @@ -198,18 +198,19 @@ private void PolynomialMultiplication(int uIndex, int uLen, int vIndex, int vLen
Contracts.Assert(vIndex + vLen <= Utils.Size(Coefficients));
Contracts.Assert(uIndex + uLen <= vIndex || vIndex + vLen <= uIndex); // makes sure the input ranges are non-overlapping.
Contracts.Assert(dstIndex >= 0);
Contracts.Assert(dstIndex + uLen + vLen <= Utils.Size(Destination));
Contracts.Assert(dstIndex + uLen + vLen <= Utils.Size(destination));

if (uLen == 1 && vLen == 1)
{
Destination[dstIndex] = Coefficients[uIndex] * Coefficients[vIndex];
Destination[dstIndex + 1] = Coefficients[uIndex] + Coefficients[vIndex];
destination[dstIndex] = Coefficients[uIndex] * Coefficients[vIndex];
destination[dstIndex + 1] = Coefficients[uIndex] + Coefficients[vIndex];
}
else
NaivePolynomialMultiplication(uIndex, uLen, vIndex, vLen, dstIndex, uCoeff, vCoeff);
NaivePolynomialMultiplication(destination, uIndex, uLen, vIndex, vLen, dstIndex, uCoeff, vCoeff);
}

private void NaivePolynomialMultiplication(int uIndex, int uLen, int vIndex, int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
private void NaivePolynomialMultiplication(decimal[] destination, int uIndex, int uLen, int vIndex,
int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
{
int i;
int j;
Expand Down Expand Up @@ -246,7 +247,7 @@ private void NaivePolynomialMultiplication(int uIndex, int uLen, int vIndex, int
for (j = 0; j < a; ++j)
temp += (Coefficients[b - j + uIndex] * Coefficients[c + j + vIndex]);

Destination[i + dstIndex] = temp;
destination[i + dstIndex] = temp;
}
}
}
Expand Down Expand Up @@ -285,6 +286,7 @@ public static bool FindPolynomialCoefficients(Complex[] roots, ref Double[] coef
int destinationOffset = 0;

var factors = new List<PolynomialFactor>();
decimal[] destination = new decimal[n];

for (i = 0; i < n; ++i)
{
Expand Down Expand Up @@ -336,9 +338,6 @@ public static bool FindPolynomialCoefficients(Complex[] roots, ref Double[] coef

if (destinationOffset < n - 1)
{
if (Utils.Size(PolynomialFactor.Destination) < n)
PolynomialFactor.Destination = new decimal[n];

while (factors.Count > 1)
{
var k1 = Math.Abs(factors.ElementAt(0).Key);
Expand All @@ -364,7 +363,7 @@ public static bool FindPolynomialCoefficients(Complex[] roots, ref Double[] coef
var f2 = factors.ElementAt(ind);
factors.RemoveAt(ind);

f1.Multiply(f2);
f1.Multiply(f2, destination);

ind = factors.BinarySearch(f1, comparer);
if (ind >= 0)
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine()
}

[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.1 output differs from Baseline")]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void SsaForecast()
{
var env = new MLContext(1);
Expand Down