Skip to content

Commit b5b9ed6

Browse files
committed
Update tests.
1 parent fae9570 commit b5b9ed6

File tree

8 files changed

+94
-67
lines changed

8 files changed

+94
-67
lines changed

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ public sealed class MLContext : IHostEnvironment
4646
/// </summary>
4747
public AnomalyDetectionCatalog AnomalyDetection { get; }
4848

49+
/// <summary>
50+
/// Trainers and tasks specific to forecasting problems.
51+
/// </summary>
52+
public ForecastingCatalog Forecasting { get; }
53+
4954
/// <summary>
5055
/// Data processing operations.
5156
/// </summary>
@@ -113,6 +118,7 @@ public MLContext(int? seed = null)
113118
Clustering = new ClusteringCatalog(_env);
114119
Ranking = new RankingCatalog(_env);
115120
AnomalyDetection = new AnomalyDetectionCatalog(_env);
121+
Forecasting = new ForecastingCatalog(_env);
116122
Transforms = new TransformsCatalog(_env);
117123
Model = new ModelOperationsCatalog(_env);
118124
Data = new DataOperationsCatalog(_env);

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,4 +704,31 @@ public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName =
704704
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);
705705
}
706706
}
707+
708+
/// <summary>
709+
/// Class used by <see cref="MLContext"/> to create instances of forecasting components.
710+
/// </summary>
711+
public sealed class ForecastingCatalog : TrainCatalogBase
712+
{
713+
/// <summary>
714+
/// The list of trainers for performing forecasting.
715+
/// </summary>
716+
public Forecasters Trainers { get; }
717+
718+
internal ForecastingCatalog(IHostEnvironment env) : base(env, nameof(ForecastingCatalog))
719+
{
720+
Trainers = new Forecasters(this);
721+
}
722+
723+
/// <summary>
724+
/// Class used by <see cref="MLContext"/> to create instances of forecasting trainers.
725+
/// </summary>
726+
public sealed class Forecasters : CatalogInstantiatorBase
727+
{
728+
internal Forecasters(ForecastingCatalog catalog)
729+
: base(catalog)
730+
{
731+
}
732+
}
733+
}
707734
}

src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ public static SrCnnAnomalyEstimator DetectAnomalyBySrCnn(this TransformsCatalog
179179
/// </format>
180180
/// </example>
181181
public static SsaForecastingEstimator ForecastBySsa(
182-
this TransformsCatalog catalog, string outputColumnName, string inputColumnName, int windowSize, int seriesLength, int trainSize, int horizon,
182+
this ForecastingCatalog catalog, string outputColumnName, string inputColumnName, int windowSize, int seriesLength, int trainSize, int horizon,
183183
bool isAdaptive = false, float discountFactor = 1, RankSelectionMethod rankSelectionMethod = RankSelectionMethod.Exact, int? rank = null,
184184
int? maxRank = null, bool shouldStablize = true, bool shouldMaintainInfo = false, GrowthRatio? maxGrowth = null, string forecastingConfidenceIntervalMinOutputColumnName = null,
185185
string forecastingConfidenceIntervalMaxOutputColumnName = null, float confidenceLevel = 0.95f) =>

src/Microsoft.ML.TimeSeries/SSaForecasting.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat
224224
/// ]]>
225225
/// </format>
226226
/// </remarks>
227-
/// <seealso cref="Microsoft.ML.TimeSeriesCatalog.ForecastBySsa(TransformsCatalog, string, string, int, int, int, int, bool, float, RankSelectionMethod, int?, int?, bool, bool, GrowthRatio?, string, string, float)" />
228227
public sealed class SsaForecastingEstimator : IEstimator<SsaForecasting>
229228
{
230229
private readonly IHost _host;

src/Microsoft.ML.TimeSeries/SequentialForecastingTransformBase.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,13 @@ private protected sealed override void InitializeStateCore(bool disk = false)
256256
{
257257
Parent = (SequentialForecastingTransformBase<TInput, TState>)ParentTransform;
258258
Host.Assert(WindowSize >= 0);
259-
InitializeAnomalyDetector();
259+
InitializeForecaster();
260260
}
261261

262262
/// <summary>
263-
/// The abstract method that realizes the initialization functionality for the anomaly detector.
263+
/// The abstract method that realizes the initialization functionality for the forecaster.
264264
/// </summary>
265-
private protected abstract void InitializeAnomalyDetector();
265+
private protected abstract void InitializeForecaster();
266266
}
267267
}
268268
}

src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ public void UpdateStateCore(ref TInput input, bool buffer = true)
168168

169169
public void Process(ref TInput input, ref TOutput output1, ref TOutput output2, ref TOutput output3)
170170
{
171+
//Using prediction engine will not evaluate the below condition to true.
171172
if (PreviousPosition == -1)
172173
UpdateStateCore(ref input);
173174

@@ -186,8 +187,9 @@ public void Process(ref TInput input, ref TOutput output1, ref TOutput output2,
186187

187188
public void ProcessWithoutBuffer(ref TInput input, ref TOutput output1, ref TOutput output2, ref TOutput output3)
188189
{
190+
//Using prediction engine will not evaluate the below condition to true.
189191
if (PreviousPosition == -1)
190-
UpdateStateCore(ref input, false);
192+
UpdateStateCore(ref input);
191193

192194
if (InitialWindowedBuffer.Count < InitialWindowSize)
193195
{
@@ -202,6 +204,7 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output1, ref TOut
202204

203205
public void Process(ref TInput input, ref TOutput output)
204206
{
207+
//Using prediction engine will not evaluate the below condition to true.
205208
if (PreviousPosition == -1)
206209
UpdateStateCore(ref input);
207210

@@ -220,8 +223,9 @@ public void Process(ref TInput input, ref TOutput output)
220223

221224
public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
222225
{
226+
//Using prediction engine will not evaluate the below condition to true.
223227
if (PreviousPosition == -1)
224-
UpdateStateCore(ref input, false);
228+
UpdateStateCore(ref input);
225229

226230
if (InitialWindowedBuffer.Count < InitialWindowSize)
227231
{

src/Microsoft.ML.TimeSeries/SsaForecastingBase.cs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ internal void SaveThis(ModelSaveContext ctx)
250250
internal sealed class State : ForecastingStateBase
251251
{
252252
private SequenceModelerBase<Single, Single> _model;
253-
private SsaForecastingBase _parentAnomalyDetector;
253+
private SsaForecastingBase _parentForecaster;
254254

255255
public State()
256256
{
@@ -278,8 +278,8 @@ private protected override void CloneCore(State state)
278278
stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone();
279279
if (_model != null)
280280
{
281-
_parentAnomalyDetector.Model = _parentAnomalyDetector.Model.Clone();
282-
_model = _parentAnomalyDetector.Model;
281+
_parentForecaster.Model = _parentForecaster.Model.Clone();
282+
_model = _parentForecaster.Model;
283283
}
284284
}
285285

@@ -288,30 +288,40 @@ private protected override void LearnStateFromDataCore(FixedSizeQueue<Single> da
288288
// This method is empty because there is no need to implement a training logic here.
289289
}
290290

291-
private protected override void InitializeAnomalyDetector()
291+
private protected override void InitializeForecaster()
292292
{
293-
_parentAnomalyDetector = (SsaForecastingBase)Parent;
294-
_model = _parentAnomalyDetector.Model;
293+
_parentForecaster = (SsaForecastingBase)Parent;
294+
_model = _parentForecaster.Model;
295295
}
296296

297297
private protected override void TransformCore(ref float input, FixedSizeQueue<float> windowedBuffer, long iteration, ref VBuffer<float> dst)
298298
{
299-
dst = new VBuffer<float>(_parentAnomalyDetector.Horizon,
300-
((AdaptiveSingularSpectrumSequenceModelerInternal)_model).Forecast(_parentAnomalyDetector.Horizon));
299+
// Forecasting is being done without prediction engine. Update the model
300+
// with the observation.
301+
if (PreviousPosition == -1)
302+
Consume(input);
303+
304+
dst = new VBuffer<float>(_parentForecaster.Horizon,
305+
((AdaptiveSingularSpectrumSequenceModelerInternal)_model).Forecast(_parentForecaster.Horizon));
301306
}
302307

303308
private protected override void TransformCore(ref float input, FixedSizeQueue<float> windowedBuffer, long iteration,
304309
ref VBuffer<float> dst1, ref VBuffer<float> dst2, ref VBuffer<float> dst3)
305310
{
306-
((AdaptiveSingularSpectrumSequenceModelerInternal)_model).ForecastWithConfidenceIntervals(_parentAnomalyDetector.Horizon,
307-
out float[] forecast, out float[] min, out float[] max, _parentAnomalyDetector.ConfidenceLevel);
311+
// Forecasting is being done without prediction engine. Update the model
312+
// with the observation.
313+
if (PreviousPosition == -1)
314+
Consume(input);
315+
316+
((AdaptiveSingularSpectrumSequenceModelerInternal)_model).ForecastWithConfidenceIntervals(_parentForecaster.Horizon,
317+
out float[] forecast, out float[] min, out float[] max, _parentForecaster.ConfidenceLevel);
308318

309-
dst1 = new VBuffer<float>(_parentAnomalyDetector.Horizon, forecast);
310-
dst2 = new VBuffer<float>(_parentAnomalyDetector.Horizon, min);
311-
dst3 = new VBuffer<float>(_parentAnomalyDetector.Horizon, max);
319+
dst1 = new VBuffer<float>(_parentForecaster.Horizon, forecast);
320+
dst2 = new VBuffer<float>(_parentForecaster.Horizon, min);
321+
dst3 = new VBuffer<float>(_parentForecaster.Horizon, max);
312322
}
313323

314-
public override void Consume(Single input) => _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive);
324+
public override void Consume(Single input) => _model.Consume(ref input, _parentForecaster.IsAdaptive);
315325
}
316326
}
317327
}

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ private sealed class ForecastPrediction
2626
{
2727
#pragma warning disable CS0649
2828
[VectorType(4)]
29-
public float[] Change;
29+
public float[] Forecast;
3030
[VectorType(4)]
31-
public float[] Min;
31+
public float[] MinCnf;
3232
[VectorType(4)]
33-
public float[] Max;
33+
public float[] MaxCnf;
3434
#pragma warning restore CS0649
3535
}
3636

@@ -182,13 +182,14 @@ public void Forecast()
182182
{
183183
ConfidenceLevel = 0.95f,
184184
Source = "Value",
185-
Name = "Change",
186-
ForecastingConfidenceIntervalMinOutputColumnName = "Min",
187-
ForecastingConfidenceIntervalMaxOutputColumnName = "Max",
185+
Name = "Forecast",
186+
ForecastingConfidenceIntervalMinOutputColumnName = "MinCnf",
187+
ForecastingConfidenceIntervalMaxOutputColumnName = "MaxCnf",
188188
WindowSize = 10,
189189
SeriesLength = 11,
190190
TrainSize = 22,
191-
Horizon = 4
191+
Horizon = 4,
192+
IsAdaptive = true
192193
};
193194

194195
for (int j = 0; j < NumberOfSeasonsInTraining; j++)
@@ -205,18 +206,19 @@ public void Forecast()
205206
// Get predictions
206207
var enumerator = env.Data.CreateEnumerable<ForecastPrediction>(output, true).GetEnumerator();
207208
ForecastPrediction row = null;
208-
List<double> expectedValues = new List<double>() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
209-
0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0, -1.0950627326965332, 0.36663890634019225, 0.026956459625565483};
209+
List<float> expectedForecast = new List<float>() { 0.191491723f, 2.53994083f, 5.26454258f, 7.37313938f };
210+
List<float> minCnf = new List<float>() { -3.9741993f, -2.36872721f, 0.09407653f, 2.18899345f };
211+
List<float> maxCnf = new List<float>() { 4.3571825f, 7.448609f, 10.435009f, 12.5572853f };
212+
enumerator.MoveNext();
213+
row = enumerator.Current;
210214

211-
int index = 0;
212-
while (enumerator.MoveNext() && index < expectedValues.Count)
215+
for (int localIndex = 0; localIndex < 4; localIndex++)
213216
{
214-
row = enumerator.Current;
215-
/*Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert
216-
Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score
217-
Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score
218-
Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score*/
217+
Assert.Equal(expectedForecast[localIndex], row.Forecast[localIndex], precision: 7);
218+
Assert.Equal(minCnf[localIndex], row.MinCnf[localIndex], precision: 7);
219+
Assert.Equal(maxCnf[localIndex], row.MaxCnf[localIndex], precision: 7);
219220
}
221+
220222
}
221223

222224
[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")]
@@ -384,7 +386,7 @@ public void ForecastingPredictionEngine()
384386
WindowSize = 10,
385387
SeriesLength = 11,
386388
TrainSize = 22,
387-
Horizon = 4
389+
Horizon = 4,
388390
};
389391

390392
for (int j = 0; j < NumberOfSeasonsInTraining; j++)
@@ -402,37 +404,16 @@ public void ForecastingPredictionEngine()
402404
engine.Update(new Data(1));
403405
engine.Update(new Data(2));
404406
var forecast = engine.Forecast(horizon: 5);
405-
var prediction = ml.Data.CreateEnumerable<ForecastResult>(forecast, reuseRowObject: false);
406-
407-
408-
/*Assert.Equal(0, prediction.Change[0], precision: 7); // Alert
409-
Assert.Equal(1.1661833524703979, prediction.Change[1], precision: 5); // Raw score
410-
Assert.Equal(0.5, prediction.Change[2], precision: 7); // P-Value score
411-
Assert.Equal(5.1200000000000114E-08, prediction.Change[3], precision: 7); // Martingale score
412-
413-
//Model 1: Checkpoint.
414-
var modelPath = "temp.zip";
415-
engine.CheckPoint(ml, modelPath);
416-
417-
//Model 1: Prediction #2
418-
prediction = engine.Predict(new Data(1));
419-
Assert.Equal(0, prediction.Change[0], precision: 7); // Alert
420-
Assert.Equal(0.12216401100158691, prediction.Change[1], precision: 5); // Raw score
421-
Assert.Equal(0.14823824685192111, prediction.Change[2], precision: 5); // P-Value score
422-
Assert.Equal(1.5292508189989167E-07, prediction.Change[3], precision: 7); // Martingale score
407+
var prediction = ml.Data.CreateEnumerable<ForecastResult>(forecast, reuseRowObject: false).GetEnumerator();
423408

424-
// Load Model 1.
425-
ITransformer model2 = null;
426-
using (var file = File.OpenRead(modelPath))
427-
model2 = ml.Model.Load(file, out var schema);
409+
int index = 0;
410+
List<float> expectedForecast = new List<float>() { 3.9516573f, 6.212672f, 7.732854f, 8.125769f, 7.22453928f };
411+
while (prediction.MoveNext())
412+
{
413+
Assert.Equal(prediction.Current.Forecast, expectedForecast[index++]);
414+
}
428415

429-
//Predict and expect the same result after checkpointing(Prediction #2).
430-
engine = model2.CreateTimeSeriesPredictionFunction<Data, Prediction>(ml);
431-
prediction = engine.Predict(new Data(1));
432-
Assert.Equal(0, prediction.Change[0], precision: 7); // Alert
433-
Assert.Equal(0.12216401100158691, prediction.Change[1], precision: 5); // Raw score
434-
Assert.Equal(0.14823824685192111, prediction.Change[2], precision: 5); // P-Value score
435-
Assert.Equal(1.5292508189989167E-07, prediction.Change[3], precision: 5); // Martingale score*/
416+
Assert.Equal(5, index);
436417
}
437418

438419
[Fact]

0 commit comments

Comments
 (0)