Skip to content

Commit

Permalink
Fix SsaForecast bug (#5023)
Browse files Browse the repository at this point in the history
* fix SsaForecast

* refind comments

* remove static fields to local variable
  • Loading branch information
frank-dong-ms-zz authored Apr 14, 2020
1 parent fbea448 commit 38f34b8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
29 changes: 14 additions & 15 deletions src/Microsoft.ML.TimeSeries/PolynomialUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ public FactorMultiplicity(int multiplicity = 1)
private sealed class PolynomialFactor
{
public List<decimal> Coefficients;
public static decimal[] Destination;

private decimal _key;
public decimal Key { get { return _key; } }
Expand Down Expand Up @@ -175,20 +174,21 @@ internal PolynomialFactor(decimal key)
_key = key;
}

public void Multiply(PolynomialFactor factor)
public void Multiply(PolynomialFactor factor, decimal[] destination)
{
var len = Coefficients.Count;
Coefficients.AddRange(factor.Coefficients);

PolynomialMultiplication(0, len, len, factor.Coefficients.Count, 0, 1, 1);
PolynomialMultiplication(destination, 0, len, len, factor.Coefficients.Count, 0, 1, 1);

for (var i = 0; i < Coefficients.Count; ++i)
Coefficients[i] = Destination[i];
Coefficients[i] = destination[i];

SetKey();
}

private void PolynomialMultiplication(int uIndex, int uLen, int vIndex, int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
private void PolynomialMultiplication(decimal[] destination, int uIndex, int uLen, int vIndex,
int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
{
Contracts.Assert(uIndex >= 0);
Contracts.Assert(uLen >= 1);
Expand All @@ -198,18 +198,19 @@ private void PolynomialMultiplication(int uIndex, int uLen, int vIndex, int vLen
Contracts.Assert(vIndex + vLen <= Utils.Size(Coefficients));
Contracts.Assert(uIndex + uLen <= vIndex || vIndex + vLen <= uIndex); // makes sure the input ranges are non-overlapping.
Contracts.Assert(dstIndex >= 0);
Contracts.Assert(dstIndex + uLen + vLen <= Utils.Size(Destination));
Contracts.Assert(dstIndex + uLen + vLen <= Utils.Size(destination));

if (uLen == 1 && vLen == 1)
{
Destination[dstIndex] = Coefficients[uIndex] * Coefficients[vIndex];
Destination[dstIndex + 1] = Coefficients[uIndex] + Coefficients[vIndex];
destination[dstIndex] = Coefficients[uIndex] * Coefficients[vIndex];
destination[dstIndex + 1] = Coefficients[uIndex] + Coefficients[vIndex];
}
else
NaivePolynomialMultiplication(uIndex, uLen, vIndex, vLen, dstIndex, uCoeff, vCoeff);
NaivePolynomialMultiplication(destination, uIndex, uLen, vIndex, vLen, dstIndex, uCoeff, vCoeff);
}

private void NaivePolynomialMultiplication(int uIndex, int uLen, int vIndex, int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
private void NaivePolynomialMultiplication(decimal[] destination, int uIndex, int uLen, int vIndex,
int vLen, int dstIndex, decimal uCoeff, decimal vCoeff)
{
int i;
int j;
Expand Down Expand Up @@ -246,7 +247,7 @@ private void NaivePolynomialMultiplication(int uIndex, int uLen, int vIndex, int
for (j = 0; j < a; ++j)
temp += (Coefficients[b - j + uIndex] * Coefficients[c + j + vIndex]);

Destination[i + dstIndex] = temp;
destination[i + dstIndex] = temp;
}
}
}
Expand Down Expand Up @@ -285,6 +286,7 @@ public static bool FindPolynomialCoefficients(Complex[] roots, ref Double[] coef
int destinationOffset = 0;

var factors = new List<PolynomialFactor>();
decimal[] destination = new decimal[n];

for (i = 0; i < n; ++i)
{
Expand Down Expand Up @@ -336,9 +338,6 @@ public static bool FindPolynomialCoefficients(Complex[] roots, ref Double[] coef

if (destinationOffset < n - 1)
{
if (Utils.Size(PolynomialFactor.Destination) < n)
PolynomialFactor.Destination = new decimal[n];

while (factors.Count > 1)
{
var k1 = Math.Abs(factors.ElementAt(0).Key);
Expand All @@ -364,7 +363,7 @@ public static bool FindPolynomialCoefficients(Complex[] roots, ref Double[] coef
var f2 = factors.ElementAt(ind);
factors.RemoveAt(ind);

f1.Multiply(f2);
f1.Multiply(f2, destination);

ind = factors.BinarySearch(f1, comparer);
if (ind >= 0)
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine()
}

[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.1 output differs from Baseline")]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void SsaForecast()
{
var env = new MLContext(1);
Expand Down

0 comments on commit 38f34b8

Please sign in to comment.