Skip to content

Commit 406c664

Browse files
committed
Add test and sample.
1 parent 5f07695 commit 406c664

File tree

6 files changed

+193
-37
lines changed

6 files changed

+193
-37
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnn.cs

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,25 @@ public static void Example()
4545
string inputColumnName = nameof(TimeSeriesData.Value);
4646

4747
// The transformed model.
48-
//ITransformer model = ml.Transforms.DetectIidSpike(outputColumnName, inputColumnName, 95, Size).Fit(dataView);
49-
ITransformer model = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 8, 5, 5, 3, 6, 0.3).Fit(dataView);
48+
ITransformer model = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 64, 5, 5, 3, 21, 0.25).Fit(dataView);
5049

5150
// Create a time series prediction engine from the model.
5251
var engine = model.CreateTimeSeriesPredictionFunction<TimeSeriesData, SrCnnAnomalyDetection>(ml);
5352

5453
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
5554

56-
5755
// Create non-anomalous data and check for anomaly.
58-
for (int index = 0; index < 5; index++)
56+
for (int index = 0; index < 100; index++)
5957
{
6058
// Anomaly spike detection.
6159
PrintPrediction(5, engine.Predict(new TimeSeriesData(5)));
6260
}
6361

64-
// 5 0 5.00 0.50
65-
// 5 0 5.00 0.50
66-
// 5 0 5.00 0.50
67-
// 5 0 5.00 0.50
68-
// 5 0 5.00 0.50
69-
7062
// Spike.
71-
PrintPrediction(10, engine.Predict(new TimeSeriesData(10)));
72-
73-
// 10 1 10.00 0.00 <-- alert is on, predicted spike (check-point model)
63+
for (int index = 0; index < 5; index++)
64+
{
65+
PrintPrediction(15, engine.Predict(new TimeSeriesData(10)));
66+
}
7467

7568
// Checkpoint the model.
7669
var modelPath = "temp.zip";
@@ -85,13 +78,6 @@ public static void Example()
8578
// Anomaly spike detection.
8679
PrintPrediction(5, engine.Predict(new TimeSeriesData(5)));
8780
}
88-
89-
// 5 0 5.00 0.26 <-- load model from disk.
90-
// 5 0 5.00 0.26
91-
// 5 0 5.00 0.50
92-
// 5 0 5.00 0.50
93-
// 5 0 5.00 0.50
94-
9581
}
9682

9783
private static void PrintPrediction(float value, SrCnnAnomalyDetection prediction) =>
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Data;
5+
6+
namespace Samples.Dynamic
7+
{
8+
public static class DetectAnomalyBySrCnnBatchPrediction
9+
{
10+
public static void Example()
11+
{
12+
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
13+
// as well as the source of randomness.
14+
var ml = new MLContext();
15+
16+
// Generate sample series data with a spike
17+
var data = new List<TimeSeriesData>();
18+
for (int index = 0; index < 100; index++)
19+
{
20+
data.Add(new TimeSeriesData(5));
21+
}
22+
for (int index = 0; index < 5; index++)
23+
{
24+
data.Add(new TimeSeriesData(15));
25+
}
26+
for (int index = 0; index < 5; index++)
27+
{
28+
data.Add(new TimeSeriesData(5));
29+
}
30+
31+
// Convert data to IDataView.
32+
var dataView = ml.Data.LoadFromEnumerable(data);
33+
34+
// Setup the estimator arguments
35+
string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction);
36+
string inputColumnName = nameof(TimeSeriesData.Value);
37+
38+
// The transformed data.
39+
var transformedData = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 64, 5, 5, 3, 21, 0.25).Fit(dataView).Transform(dataView);
40+
41+
// Getting the data of the newly created column as an IEnumerable of SrCnnAnomalyDetection.
42+
var predictionColumn = ml.Data.CreateEnumerable<SrCnnAnomalyDetection>(transformedData, reuseRowObject: false);
43+
44+
Console.WriteLine($"{outputColumnName} column obtained post-transformation.");
45+
Console.WriteLine("Data\tAlert\tScore\tP-Value");
46+
47+
int k = 0;
48+
foreach (var prediction in predictionColumn)
49+
PrintPrediction(data[k++].Value, prediction);
50+
51+
}
52+
53+
private static void PrintPrediction(float value, SrCnnAnomalyDetection prediction) =>
54+
Console.WriteLine("{0}\t{1}\t{2:0.00}\t{3:0.00}", value, prediction.Prediction[0],
55+
prediction.Prediction[1], prediction.Prediction[2]);
56+
57+
class TimeSeriesData
58+
{
59+
public float Value;
60+
61+
public TimeSeriesData(float value)
62+
{
63+
Value = value;
64+
}
65+
}
66+
67+
class SrCnnAnomalyDetection
68+
{
69+
[VectorType(3)]
70+
public double[] Prediction { get; set; }
71+
}
72+
}
73+
}

src/Microsoft.ML.TimeSeries/SRCNNAnomalyDetector.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ internal SrCnnAnomalyDetector(IHostEnvironment env, Options options)
152152
internal SrCnnAnomalyDetector(IHostEnvironment env, ModelLoadContext ctx)
153153
: base(env, ctx, LoaderSignature)
154154
{
155-
//TODO: Some data check here
156155
}
157156

158157
private SrCnnAnomalyDetector(IHostEnvironment env, SrCnnAnomalyDetector transform)

src/Microsoft.ML.TimeSeries/SrCnnAnomalyDetectionBase.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public SrCnnAnomalyDetectionBase(SrCnnArgumentBase args, string name, IHostEnvir
104104
public SrCnnAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, SrCnnAnomalyDetectionBaseWrapper parent)
105105
: base(env, ctx, name)
106106
{
107-
Host.CheckDecode(InitialWindowSize == 0);
107+
//Host.CheckDecode(InitialWindowSize == 0);
108108
StateRef = new State(ctx.Reader);
109109
StateRef.InitState(this, Host);
110110
Parent = parent;
@@ -132,7 +132,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
132132
internal void SaveThis(ModelSaveContext ctx)
133133
{
134134
ctx.CheckAtModel();
135-
Host.Assert(InitialWindowSize == 0);
135+
//Host.Assert(InitialWindowSize == 0);
136136
base.SaveModel(ctx);
137137

138138
// *** Binary format ***
@@ -147,7 +147,7 @@ public State()
147147
{
148148
}
149149

150-
internal State(BinaryReader reader)
150+
internal State(BinaryReader reader) : base(reader)
151151
{
152152
WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
153153
InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
@@ -229,17 +229,17 @@ private protected override sealed void SpectralResidual(Single input, FixedSizeQ
229229
}
230230
List<Single> filteredIfftMagList = AverageFilter(ifftMagList, Parent.JudgementWindowSize);
231231

232-
// Step 7: Calculate score
232+
// Step 7: Calculate score and set result
233233
var score = CalculateSocre(ifftMagList[data.Count-1], filteredIfftMagList[data.Count-1]);
234-
score = (score < 1) ? 0 : score;
235-
score = (score > 10) ? 10 : score;
236234
score /= 10.0f;
237-
var detres = score > Parent.AlertThreshold ? 1 : 0;
238-
var mag = ifftMagList[data.Count-1];
235+
result.Values[1] = score;
239236

240-
//Step 8: Set result
237+
score = Math.Min(score, 1);
238+
score = Math.Max(score, 0);
239+
var detres = score > Parent.AlertThreshold ? 1 : 0;
241240
result.Values[0] = detres;
242-
result.Values[1] = score;
241+
242+
var mag = ifftMagList[data.Count-1];
243243
result.Values[2] = mag;
244244
}
245245

src/Microsoft.ML.TimeSeries/SrCnnTransformBase.cs

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,28 @@ private protected SrCnnTransformBase(int windowSize, int initialWindowSize, stri
8484
private protected SrCnnTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name)
8585
: base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx)
8686
{
87-
//TODO: Read from binary format
87+
OutputLength = 3;
88+
89+
byte temp;
90+
temp = ctx.Reader.ReadByte();
91+
BackAddWindowSize = (int)temp;
92+
Host.CheckDecode(BackAddWindowSize > 0);
93+
94+
temp = ctx.Reader.ReadByte();
95+
LookaheadWindowSize = (int)temp;
96+
Host.CheckDecode(LookaheadWindowSize > 0);
97+
98+
temp = ctx.Reader.ReadByte();
99+
AvergingWindowSize = (int)temp;
100+
Host.CheckDecode(AvergingWindowSize > 0);
101+
102+
temp = ctx.Reader.ReadByte();
103+
JudgementWindowSize = (int)temp;
104+
Host.CheckDecode(JudgementWindowSize > 0);
105+
106+
temp = ctx.Reader.ReadByte();
107+
AlertThreshold = (double)temp;
108+
Host.CheckDecode(AlertThreshold >= 0 && AlertThreshold <= 1);
88109
}
89110

90111
private protected SrCnnTransformBase(SrCnnArgumentBase args, string name, IHostEnvironment env)
@@ -95,7 +116,23 @@ private protected SrCnnTransformBase(SrCnnArgumentBase args, string name, IHostE
95116

96117
private protected override void SaveModel(ModelSaveContext ctx)
97118
{
98-
//TODO: save to ctx and write to file
119+
Host.CheckValue(ctx, nameof(ctx));
120+
ctx.CheckAtModel();
121+
122+
Host.Assert(WindowSize > 0);
123+
Host.Assert(InitialWindowSize == WindowSize);
124+
Host.Assert(BackAddWindowSize > 0);
125+
Host.Assert(LookaheadWindowSize > 0);
126+
Host.Assert(AvergingWindowSize > 0);
127+
Host.Assert(JudgementWindowSize > 0);
128+
Host.Assert(AlertThreshold >= 0 && AlertThreshold <= 1);
129+
130+
base.SaveModel(ctx);
131+
ctx.Writer.Write((byte)BackAddWindowSize);
132+
ctx.Writer.Write((byte)LookaheadWindowSize);
133+
ctx.Writer.Write((byte)AvergingWindowSize);
134+
ctx.Writer.Write((byte)JudgementWindowSize);
135+
ctx.Writer.Write((byte)AlertThreshold);
99136
}
100137

101138
internal override IStatefulRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(Host, this, schema);
@@ -225,17 +262,17 @@ private protected SrCnnStateBase() { }
225262

226263
private protected override void CloneCore(TState state)
227264
{
228-
//TODO:
265+
base.CloneCore(state);
266+
Contracts.Assert(state is SrCnnStateBase);
229267
}
230268

231269
private protected SrCnnStateBase(BinaryReader reader) : base(reader)
232270
{
233-
//TODO:
234271
}
235272

236273
internal override void Save(BinaryWriter writer)
237274
{
238-
//TODO:
275+
base.Save(writer);
239276
}
240277

241278
private protected override void SetNaOutput(ref VBuffer<double> dst)
@@ -244,7 +281,7 @@ private protected override void SetNaOutput(ref VBuffer<double> dst)
244281
var editor = VBufferEditor.Create(ref dst, outputLength);
245282

246283
for (int i = 0; i < outputLength; ++i)
247-
editor.Values[i] = Double.NaN;
284+
editor.Values[i] = 0;
248285

249286
dst = editor.Commit();
250287
}

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using System.IO;
78
using Microsoft.ML.Data;
@@ -41,6 +42,22 @@ public Data(float value)
4142
}
4243
}
4344

45+
private sealed class TimeSeriesData
46+
{
47+
public float Value;
48+
49+
public TimeSeriesData(float value)
50+
{
51+
Value = value;
52+
}
53+
}
54+
55+
private sealed class SrCnnAnomalyDetection
56+
{
57+
[VectorType(3)]
58+
public double[] Prediction { get; set; }
59+
}
60+
4461
[Fact]
4562
public void ChangeDetection()
4663
{
@@ -276,5 +293,49 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine()
276293
Assert.Equal(0.14823824685192111, prediction.Change[2], precision: 5); // P-Value score
277294
Assert.Equal(1.5292508189989167E-07, prediction.Change[3], precision: 5); // Martingale score
278295
}
296+
297+
[Fact]
298+
public void AnomalyDetectionWithSrCnn()
299+
{
300+
var ml = new MLContext();
301+
302+
// Generate sample series data with a spike
303+
var data = new List<TimeSeriesData>();
304+
for (int index = 0; index < 100; index++)
305+
{
306+
data.Add(new TimeSeriesData(5));
307+
}
308+
for (int index = 0; index < 5; index++)
309+
{
310+
data.Add(new TimeSeriesData(15));
311+
}
312+
for (int index = 0; index < 5; index++)
313+
{
314+
data.Add(new TimeSeriesData(5));
315+
}
316+
317+
// Convert data to IDataView.
318+
var dataView = ml.Data.LoadFromEnumerable(data);
319+
320+
// Setup the estimator arguments
321+
string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction);
322+
string inputColumnName = nameof(TimeSeriesData.Value);
323+
324+
// The transformed data.
325+
var transformedData = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 64, 5, 5, 3, 21, 0.25).Fit(dataView).Transform(dataView);
326+
327+
// Getting the data of the newly created column as an IEnumerable of SrCnnAnomalyDetection.
328+
var predictionColumn = ml.Data.CreateEnumerable<SrCnnAnomalyDetection>(transformedData, reuseRowObject: false);
329+
330+
int k = 0;
331+
foreach (var prediction in predictionColumn)
332+
{
333+
if (k == 101 || k == 106)
334+
Assert.Equal(1, prediction.Prediction[0]);
335+
else
336+
Assert.Equal(0, prediction.Prediction[0]);
337+
k += 1;
338+
}
339+
}
279340
}
280341
}

0 commit comments

Comments
 (0)